├── BBSNet_test.py ├── BBSNet_train.py ├── Images ├── backbone_result.png ├── detailed-comparisons.png ├── pipeline.png └── resultmap.png ├── LICENSE ├── README.md ├── data.py ├── models ├── BBSNet_model.py ├── ResNet.py └── __pycache__ │ ├── BBSNet_model.cpython-37.pyc │ ├── ResNet.cpython-36.pyc │ └── ResNet.cpython-37.pyc ├── options.py └── utils.py /BBSNet_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | sys.path.append('./models') 5 | import numpy as np 6 | import os, argparse 7 | import cv2 8 | from models.BBSNet_model import BBSNet 9 | from data import test_dataset 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--testsize', type=int, default=352, help='testing size') 14 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id') 15 | parser.add_argument('--test_path',type=str,default='../BBS_dataset/RGBD_for_test/',help='test dataset path') 16 | opt = parser.parse_args() 17 | 18 | dataset_path = opt.test_path 19 | 20 | #set device for test 21 | if opt.gpu_id=='0': 22 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 23 | print('USE GPU 0') 24 | elif opt.gpu_id=='1': 25 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 26 | print('USE GPU 1') 27 | 28 | #load the model 29 | model = BBSNet() 30 | #Large epoch size may not generalize well. You can choose a good model to load according to the log file and pth files saved in ('./BBSNet_cpts/') when training. 31 | model.load_state_dict(torch.load('./model_pths/BBSNet.pth')) 32 | model.cuda() 33 | model.eval() 34 | 35 | #test 36 | test_datasets = ['NJU2K','NLPR','STERE', 'DES', 'SSD','LFSD','SIP'] 37 | for dataset in test_datasets: 38 | save_path = './test_maps/BBSNet/ResNet50/' + dataset + '/' 39 | if not os.path.exists(save_path): 40 | os.makedirs(save_path) 41 | image_root = dataset_path + dataset + '/RGB/' 42 | gt_root = dataset_path + dataset + '/GT/' 43 | depth_root=dataset_path +dataset +'/depth/' 44 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize) 45 | for i in range(test_loader.size): 46 | image, gt,depth, name, image_for_post = test_loader.load_data() 47 | gt = np.asarray(gt, np.float32) 48 | gt /= (gt.max() + 1e-8) 49 | image = image.cuda() 50 | depth = depth.cuda() 51 | _,res = model(image,depth) 52 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 53 | res = res.sigmoid().data.cpu().numpy().squeeze() 54 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 55 | print('save img to: ',save_path+name) 56 | cv2.imwrite(save_path+name,res*255) 57 | print('Test Done!') 58 | -------------------------------------------------------------------------------- /BBSNet_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append('./models') 6 | import numpy as np 7 | from datetime import datetime 8 | from torchvision.utils import make_grid 9 | from models.BBSNet_model import BBSNet 10 | from data import get_loader,test_dataset 11 | from utils import clip_gradient, adjust_lr 12 | from tensorboardX import SummaryWriter 13 | import logging 14 | import torch.backends.cudnn as cudnn 15 | from options import opt 16 | 17 | #set the device for training 18 | if opt.gpu_id=='0': 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 20 | print('USE GPU 0') 21 | elif opt.gpu_id=='1': 22 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 
23 | print('USE GPU 1') 24 | cudnn.benchmark = True 25 | 26 | #build the model 27 | model = BBSNet() 28 | if(opt.load is not None): 29 | model.load_state_dict(torch.load(opt.load)) 30 | print('load model from ',opt.load) 31 | 32 | model.cuda() 33 | params = model.parameters() 34 | optimizer = torch.optim.Adam(params, opt.lr) 35 | 36 | #set the path 37 | image_root = opt.rgb_root 38 | gt_root = opt.gt_root 39 | depth_root=opt.depth_root 40 | test_image_root=opt.test_rgb_root 41 | test_gt_root=opt.test_gt_root 42 | test_depth_root=opt.test_depth_root 43 | save_path=opt.save_path 44 | 45 | if not os.path.exists(save_path): 46 | os.makedirs(save_path) 47 | 48 | #load data 49 | print('load data...') 50 | train_loader = get_loader(image_root, gt_root,depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 51 | test_loader = test_dataset(test_image_root, test_gt_root,test_depth_root, opt.trainsize) 52 | total_step = len(train_loader) 53 | 54 | logging.basicConfig(filename=save_path+'log.log',format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level = logging.INFO,filemode='a',datefmt='%Y-%m-%d %I:%M:%S %p') 55 | logging.info("BBSNet-Train") 56 | logging.info("Config") 57 | logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format(opt.epoch,opt.lr,opt.batchsize,opt.trainsize,opt.clip,opt.decay_rate,opt.load,save_path,opt.decay_epoch)) 58 | 59 | #set loss function 60 | CE = torch.nn.BCEWithLogitsLoss() 61 | 62 | step=0 63 | writer = SummaryWriter(save_path+'summary') 64 | best_mae=1 65 | best_epoch=0 66 | 67 | #train function 68 | def train(train_loader, model, optimizer, epoch,save_path): 69 | global step 70 | model.train() 71 | loss_all=0 72 | epoch_step=0 73 | try: 74 | for i, (images, gts, depths) in enumerate(train_loader, start=1): 75 | optimizer.zero_grad() 76 | 77 | images = images.cuda() 78 | gts = gts.cuda() 79 | depths=depths.cuda() 80 | 81 | s1,s2 = model(images,depths) 82 | loss1 = CE(s1, gts) 83 | loss2 =CE(s2,gts) 84 | loss = loss1+loss2 85 | loss.backward() 86 | 87 | clip_gradient(optimizer, opt.clip) 88 | optimizer.step() 89 | step+=1 90 | epoch_step+=1 91 | loss_all+=loss.data 92 | if i % 100 == 0 or i == total_step or i==1: 93 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'. 94 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss2.data)) 95 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'. 
 96 |                     format( epoch, opt.epoch, i, total_step, loss1.data, loss2.data))
 97 |                 writer.add_scalar('Loss', loss.data, global_step=step)
 98 |                 grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True)
 99 |                 writer.add_image('RGB', grid_image, step)
100 |                 grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True)
101 |                 writer.add_image('Ground_truth', grid_image, step)
102 |                 res=s1[0].clone()
103 |                 res = res.sigmoid().data.cpu().numpy().squeeze()
104 |                 res = (res - res.min()) / (res.max() - res.min() + 1e-8)
105 |                 writer.add_image('s1', torch.tensor(res), step,dataformats='HW')
106 |                 res=s2[0].clone()
107 |                 res = res.sigmoid().data.cpu().numpy().squeeze()
108 |                 res = (res - res.min()) / (res.max() - res.min() + 1e-8)
109 |                 writer.add_image('s2', torch.tensor(res), step,dataformats='HW')
110 | 
111 |         loss_all/=epoch_step
112 |         logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format( epoch, opt.epoch, loss_all))
113 |         writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)
114 |         if (epoch) % 5 == 0:
115 |             torch.save(model.state_dict(), save_path+'BBSNet_epoch_{}.pth'.format(epoch))
116 |     except KeyboardInterrupt:
117 |         print('Keyboard Interrupt: save model and exit.')
118 |         if not os.path.exists(save_path):
119 |             os.makedirs(save_path)
120 |         torch.save(model.state_dict(), save_path+'BBSNet_epoch_{}.pth'.format(epoch+1))
121 |         print('save checkpoints successfully!')
122 |         raise
123 | 
124 | #test function
125 | def test(test_loader,model,epoch,save_path):
126 |     global best_mae,best_epoch
127 |     model.eval()
128 |     with torch.no_grad():
129 |         mae_sum=0
130 |         for i in range(test_loader.size):
131 |             image, gt,depth, name,img_for_post = test_loader.load_data()
132 |             gt = np.asarray(gt, np.float32)
133 |             gt /= (gt.max() + 1e-8)
134 |             image = image.cuda()
135 |             depth = depth.cuda()
136 |             _,res = model(image,depth)
137 |             res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
138 |             res = res.sigmoid().data.cpu().numpy().squeeze()
139 |             res = (res - res.min()) / (res.max() - res.min() + 1e-8)
140 |             mae_sum+=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1])
141 |         mae=mae_sum/test_loader.size
142 |         writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch)
143 |         print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch))
144 |         if epoch==1:
145 |             best_mae=mae
146 |         else:
147 |             #reconstructed from context: keep the checkpoint with the lowest MAE seen so far
148 |             if mae<best_mae:
149 |                 best_mae=mae
150 |                 best_epoch=epoch
151 |                 torch.save(model.state_dict(), save_path+'BBSNet_epoch_best.pth')
152 |                 print('best epoch:{}'.format(epoch))
153 |         logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch,mae,best_epoch,best_mae))
154 | 
155 | if __name__ == '__main__':
156 |     print("Start train...")
157 |     for epoch in range(1, opt.epoch):
158 |         #decay the learning rate, then train for one epoch and validate on the test-in-train split
159 |         cur_lr=adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
160 |         writer.add_scalar('learning_rate', cur_lr, global_step=epoch)
161 |         train(train_loader, model, optimizer, epoch, save_path)
162 |         test(test_loader, model, epoch, save_path)
163 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network (ECCV 2020)
2 | 
3 | Figure 1: Pipeline of the BBS-Net. (Images/pipeline.png)
11 | 
12 | ## 1. Requirements
13 | 
14 | Python 3.7, PyTorch 0.4.0+, CUDA 10.0, TensorboardX 2.0, opencv-python
15 | 
16 | ## 2. Data Preparation
17 | 
18 | - Download the raw data from [Baidu Pan](https://pan.baidu.com/s/1SxBjlTF4Tb74WjuDsRmM3w) [code: yiy1] or [Google Drive](https://drive.google.com/drive/folders/1gIMun9bM5JDrs98sLjXt7XoFCdvy1DXF?usp=sharing) and the trained model (BBSNet.pth) from [Here](https://pan.baidu.com/s/1Fn-Hvdou4DDWcgeTtx081g) [code: dwcp]. Then put them under the following directory:
19 | 
20 | 	-BBS_dataset\
21 | 	   -RGBD_for_train\
22 | 	   -RGBD_for_test\
23 | 	   -test_in_train\
24 | 	-BBSNet
25 | 	   -models\
26 | 	   -model_pths\
27 | 	      -BBSNet.pth
28 | 	   ...
29 | 
30 | - Note that the depth maps of the raw data above are not normalized. Training and testing with normalized depth maps further improves performance.
31 | 
32 | ## 3. Training & Testing
33 | 
34 | - Train the BBSNet:
35 | 
36 | `python BBSNet_train.py --batchsize 10 --gpu_id 0`
37 | 
38 | - Test the BBSNet:
39 | 
40 | `python BBSNet_test.py --gpu_id 0`
41 | 
42 | The test maps will be saved to './test_maps/'.
43 | 
44 | - Evaluate the result maps:
45 | 
46 | You can evaluate the result maps using the tool in [Python_GPU Version](https://github.com/zyjwuyan/SOD_Evaluation_Metrics) or [Matlab Version](http://dpfan.net/d3netbenchmark/); a minimal MAE sketch is also given below Table 1.
47 | 
48 | - If you need the code for the VGG16 and VGG19 backbones, please send an email to zhaiyingjier@163.com and provide your name and institution. Please note that the code may only be used for research purposes.
49 | ## 4. Results
50 | ### 4.1 Qualitative Comparison

54 | Figure 2: Qualitative visual comparison of the proposed model versus 8 SOTA models.
61 | Table 1: Quantitative comparison of models using S-measure, max F-measure, max E-measure, and MAE scores on 7 datasets.
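
As a quick reference, the sketch below shows one way to compute the MAE reported above from the saved prediction maps and the ground-truth masks. It is only a sketch: the directory layout follows the save path used by `BBSNet_test.py`, the GT path is a placeholder, and the linked evaluation tools remain the recommended way to reproduce all four metrics.

```python
import os
import cv2
import numpy as np

def mae_for_dataset(pred_dir, gt_dir):
    """Mean absolute error between predicted saliency maps and GT masks, both scaled to [0, 1]."""
    maes = []
    for name in sorted(os.listdir(gt_dir)):
        gt = cv2.imread(os.path.join(gt_dir, name), cv2.IMREAD_GRAYSCALE)
        pred = cv2.imread(os.path.join(pred_dir, name), cv2.IMREAD_GRAYSCALE)
        if gt is None or pred is None:
            continue  # skip pairs missing on either side
        if pred.shape != gt.shape:
            pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]))  # dsize is (width, height)
        maes.append(np.abs(pred / 255.0 - gt / 255.0).mean())
    return float(np.mean(maes))

# Hypothetical paths following the repository's save layout:
# print(mae_for_dataset('./test_maps/BBSNet/ResNet50/NJU2K/', '../BBS_dataset/RGBD_for_test/NJU2K/GT/'))
```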

73 | ### 4.2 Results of multiple backbones

78 | Table 2: Performance comparison using different backbones. (Images/backbone_result.png)

81 | 
82 | ### 4.3 Download
83 | - Test maps of the above datasets (ResNet50 backbone) can be downloaded from [here](https://pan.baidu.com/s/1O-AhThLWEDVgQiPhX3QVYw) [code: qgai].
84 | - Test maps of the VGG16 and VGG19 backbones of our model can be downloaded from [here](https://pan.baidu.com/s/1_hG3hC2Fpt1cbAWuPrHEPA) [code: zuds].
85 | - Test maps of the DUT-RGBD dataset (using the training-test splits proposed by [DMRA](https://openaccess.thecvf.com/content_ICCV_2019/papers/Piao_Depth-Induced_Multi-Scale_Recurrent_Attention_Network_for_Saliency_Detection_ICCV_2019_paper.pdf)) can be downloaded from [here](https://pan.baidu.com/s/15oc_-nwEKNiU1C9WRho5lg) [code: 3nme].
86 | ## 5. Citation
87 | 
88 | Please cite the following paper if you use this repository in your research.
89 | 
90 | 	@inproceedings{fan2020bbsnet,
91 | 	  title={BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network},
92 | 	  author={Fan, Deng-Ping and Zhai, Yingjie and Borji, Ali and Yang, Jufeng and Shao, Ling},
93 | 	  booktitle={ECCV},
94 | 	  year={2020}
95 | 	}
96 | 
97 | - For more information about BBS-Net, please read the [Manuscript (PDF)](https://arxiv.org/pdf/2007.02713.pdf) ([Chinese version](https://pan.baidu.com/s/1zxni7QjBiewwA1Q-m7Cqfg) [code: 0r4a]).
98 | - Note that there is an error in Fig. 3(c) of the ECCV version: the second and third BConv3 in the first column of the figure should be BConv5 and BConv7, respectively.
99 | 
100 | ## 6. Benchmark RGB-D SOD
101 | 
102 | The complete RGB-D SOD benchmark can be found on this page:
103 | 
104 | http://dpfan.net/d3netbenchmark/
105 | 
106 | ## 7. Acknowledgement
107 | We implement this project based on the code of 'Cascaded Partial Decoder for Fast and Accurate Salient Object Detection' (CVPR 2019) by Wu et al.
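
As a supplement to Section 3, here is a minimal single-pair inference sketch distilled from `BBSNet_test.py` and `data.py`. The input file names `demo_rgb.jpg`/`demo_depth.png` and the checkpoint path are placeholders; preprocessing mirrors the transforms used in this repository.

```python
import sys
sys.path.append('./models')  # BBSNet_model imports ResNet from this folder
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms
from models.BBSNet_model import BBSNet

# placeholder paths -- adjust to your own data and checkpoint
rgb_path, depth_path, ckpt_path = 'demo_rgb.jpg', 'demo_depth.png', './model_pths/BBSNet.pth'

rgb_tf = transforms.Compose([
    transforms.Resize((352, 352)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
depth_tf = transforms.Compose([transforms.Resize((352, 352)), transforms.ToTensor()])

rgb = Image.open(rgb_path).convert('RGB')
depth = Image.open(depth_path).convert('L')
w, h = rgb.size

model = BBSNet().cuda().eval()
model.load_state_dict(torch.load(ckpt_path))
with torch.no_grad():
    _, res = model(rgb_tf(rgb).unsqueeze(0).cuda(), depth_tf(depth).unsqueeze(0).cuda())
    res = F.interpolate(res, size=(h, w), mode='bilinear', align_corners=False)
    res = res.sigmoid().squeeze().cpu().numpy()
    res = (res - res.min()) / (res.max() - res.min() + 1e-8)  # min-max normalize, as in the test script
cv2.imwrite('demo_saliency.png', (res * 255).astype(np.uint8))
```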
108 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | #several data augumentation strategies 10 | def cv_random_flip(img, label,depth): 11 | flip_flag = random.randint(0, 1) 12 | # flip_flag2= random.randint(0,1) 13 | #left right flip 14 | if flip_flag == 1: 15 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 16 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 17 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 18 | #top bottom flip 19 | # if flip_flag2==1: 20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 23 | return img, label, depth 24 | def randomCrop(image, label,depth): 25 | border=30 26 | image_width = image.size[0] 27 | image_height = image.size[1] 28 | crop_win_width = np.random.randint(image_width-border , image_width) 29 | crop_win_height = np.random.randint(image_height-border , image_height) 30 | random_region = ( 31 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 32 | (image_height + crop_win_height) >> 1) 33 | return image.crop(random_region), label.crop(random_region),depth.crop(random_region) 34 | def randomRotation(image,label,depth): 35 | mode=Image.BICUBIC 36 | if random.random()>0.8: 37 | random_angle = np.random.randint(-15, 15) 38 | image=image.rotate(random_angle, mode) 39 | label=label.rotate(random_angle, mode) 40 | depth=depth.rotate(random_angle, mode) 41 | return image,label,depth 42 | def colorEnhance(image): 43 | bright_intensity=random.randint(5,15)/10.0 44 | image=ImageEnhance.Brightness(image).enhance(bright_intensity) 45 | contrast_intensity=random.randint(5,15)/10.0 46 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity) 47 | color_intensity=random.randint(0,20)/10.0 48 | image=ImageEnhance.Color(image).enhance(color_intensity) 49 | sharp_intensity=random.randint(0,30)/10.0 50 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity) 51 | return image 52 | def randomGaussian(image, mean=0.1, sigma=0.35): 53 | def gaussianNoisy(im, mean=mean, sigma=sigma): 54 | for _i in range(len(im)): 55 | im[_i] += random.gauss(mean, sigma) 56 | return im 57 | img = np.asarray(image) 58 | width, height = img.shape 59 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 60 | img = img.reshape([width, height]) 61 | return Image.fromarray(np.uint8(img)) 62 | def randomPeper(img): 63 | 64 | img=np.array(img) 65 | noiseNum=int(0.0015*img.shape[0]*img.shape[1]) 66 | for i in range(noiseNum): 67 | 68 | randX=random.randint(0,img.shape[0]-1) 69 | 70 | randY=random.randint(0,img.shape[1]-1) 71 | 72 | if random.randint(0,1)==0: 73 | 74 | img[randX,randY]=0 75 | 76 | else: 77 | 78 | img[randX,randY]=255 79 | return Image.fromarray(img) 80 | 81 | # dataset for training 82 | #The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 83 | #(e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 
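#Illustrative sketch (not wired into the loaders below): one possible way to rescale a raw
#depth map into [0, 1] before training/testing, as suggested by the note above. Whether the
#foreground should end up near 1 (and the background near 0) depends on the dataset's depth
#convention, so treat the invert flag as an assumption rather than part of the original pipeline.
def normalize_depth(depth_img, invert=False):
    d = np.asarray(depth_img, np.float32)
    d = (d - d.min()) / (d.max() - d.min() + 1e-8)
    if invert:
        d = 1.0 - d
    return Image.fromarray(np.uint8(d * 255))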
84 | class SalObjDataset(data.Dataset): 85 | def __init__(self, image_root, gt_root,depth_root, trainsize): 86 | self.trainsize = trainsize 87 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 88 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 89 | or f.endswith('.png')] 90 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 91 | or f.endswith('.png')] 92 | self.images = sorted(self.images) 93 | self.gts = sorted(self.gts) 94 | self.depths=sorted(self.depths) 95 | self.filter_files() 96 | self.size = len(self.images) 97 | self.img_transform = transforms.Compose([ 98 | transforms.Resize((self.trainsize, self.trainsize)), 99 | transforms.ToTensor(), 100 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 101 | self.gt_transform = transforms.Compose([ 102 | transforms.Resize((self.trainsize, self.trainsize)), 103 | transforms.ToTensor()]) 104 | self.depths_transform = transforms.Compose([transforms.Resize((self.trainsize, self.trainsize)),transforms.ToTensor()]) 105 | 106 | def __getitem__(self, index): 107 | image = self.rgb_loader(self.images[index]) 108 | gt = self.binary_loader(self.gts[index]) 109 | depth=self.binary_loader(self.depths[index]) 110 | image,gt,depth =cv_random_flip(image,gt,depth) 111 | image,gt,depth=randomCrop(image, gt,depth) 112 | image,gt,depth=randomRotation(image, gt,depth) 113 | image=colorEnhance(image) 114 | # gt=randomGaussian(gt) 115 | gt=randomPeper(gt) 116 | image = self.img_transform(image) 117 | gt = self.gt_transform(gt) 118 | depth=self.depths_transform(depth) 119 | 120 | return image, gt, depth 121 | 122 | def filter_files(self): 123 | assert len(self.images) == len(self.gts) and len(self.gts)==len(self.images) 124 | images = [] 125 | gts = [] 126 | depths=[] 127 | for img_path, gt_path,depth_path in zip(self.images, self.gts, self.depths): 128 | img = Image.open(img_path) 129 | gt = Image.open(gt_path) 130 | depth= Image.open(depth_path) 131 | if img.size == gt.size and gt.size==depth.size: 132 | images.append(img_path) 133 | gts.append(gt_path) 134 | depths.append(depth_path) 135 | self.images = images 136 | self.gts = gts 137 | self.depths=depths 138 | 139 | def rgb_loader(self, path): 140 | with open(path, 'rb') as f: 141 | img = Image.open(f) 142 | return img.convert('RGB') 143 | 144 | def binary_loader(self, path): 145 | with open(path, 'rb') as f: 146 | img = Image.open(f) 147 | return img.convert('L') 148 | 149 | def resize(self, img, gt, depth): 150 | assert img.size == gt.size and gt.size==depth.size 151 | w, h = img.size 152 | if h < self.trainsize or w < self.trainsize: 153 | h = max(h, self.trainsize) 154 | w = max(w, self.trainsize) 155 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST),depth.resize((w, h), Image.NEAREST) 156 | else: 157 | return img, gt, depth 158 | 159 | def __len__(self): 160 | return self.size 161 | 162 | #dataloader for training 163 | def get_loader(image_root, gt_root,depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True): 164 | 165 | dataset = SalObjDataset(image_root, gt_root, depth_root,trainsize) 166 | data_loader = data.DataLoader(dataset=dataset, 167 | batch_size=batchsize, 168 | shuffle=shuffle, 169 | num_workers=num_workers, 170 | pin_memory=pin_memory) 171 | return data_loader 172 | 173 | #test dataset and loader 174 | class test_dataset: 175 | def __init__(self, image_root, gt_root,depth_root, testsize): 176 | self.testsize = testsize 177 | 
self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 178 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 179 | or f.endswith('.png')] 180 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 181 | or f.endswith('.png')] 182 | self.images = sorted(self.images) 183 | self.gts = sorted(self.gts) 184 | self.depths=sorted(self.depths) 185 | self.transform = transforms.Compose([ 186 | transforms.Resize((self.testsize, self.testsize)), 187 | transforms.ToTensor(), 188 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 189 | self.gt_transform = transforms.ToTensor() 190 | # self.gt_transform = transforms.Compose([ 191 | # transforms.Resize((self.trainsize, self.trainsize)), 192 | # transforms.ToTensor()]) 193 | self.depths_transform = transforms.Compose([transforms.Resize((self.testsize, self.testsize)),transforms.ToTensor()]) 194 | self.size = len(self.images) 195 | self.index = 0 196 | 197 | def load_data(self): 198 | image = self.rgb_loader(self.images[self.index]) 199 | image = self.transform(image).unsqueeze(0) 200 | gt = self.binary_loader(self.gts[self.index]) 201 | depth=self.binary_loader(self.depths[self.index]) 202 | depth=self.depths_transform(depth).unsqueeze(0) 203 | name = self.images[self.index].split('/')[-1] 204 | image_for_post=self.rgb_loader(self.images[self.index]) 205 | image_for_post=image_for_post.resize(gt.size) 206 | if name.endswith('.jpg'): 207 | name = name.split('.jpg')[0] + '.png' 208 | self.index += 1 209 | self.index = self.index % self.size 210 | return image, gt,depth, name,np.array(image_for_post) 211 | 212 | def rgb_loader(self, path): 213 | with open(path, 'rb') as f: 214 | img = Image.open(f) 215 | return img.convert('RGB') 216 | 217 | def binary_loader(self, path): 218 | with open(path, 'rb') as f: 219 | img = Image.open(f) 220 | return img.convert('L') 221 | def __len__(self): 222 | return self.size 223 | 224 | -------------------------------------------------------------------------------- /models/BBSNet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from ResNet import ResNet50 5 | from torch.nn import functional as F 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | class TransBasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, upsample=None, **kwargs): 15 | super(TransBasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, inplanes) 17 | self.bn1 = nn.BatchNorm2d(inplanes) 18 | self.relu = nn.ReLU(inplace=True) 19 | if upsample is not None and stride != 1: 20 | self.conv2 = nn.ConvTranspose2d(inplanes, planes, 21 | kernel_size=3, stride=stride, padding=1, 22 | output_padding=1, bias=False) 23 | else: 24 | self.conv2 = conv3x3(inplanes, planes, stride) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.upsample = upsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.upsample is not None: 40 | residual = self.upsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | class ChannelAttention(nn.Module): 47 | 
def __init__(self, in_planes, ratio=16): 48 | super(ChannelAttention, self).__init__() 49 | 50 | self.max_pool = nn.AdaptiveMaxPool2d(1) 51 | 52 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 53 | self.relu1 = nn.ReLU() 54 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 55 | 56 | self.sigmoid = nn.Sigmoid() 57 | def forward(self, x): 58 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 59 | out = max_out 60 | return self.sigmoid(out) 61 | 62 | class SpatialAttention(nn.Module): 63 | def __init__(self, kernel_size=7): 64 | super(SpatialAttention, self).__init__() 65 | 66 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 67 | padding = 3 if kernel_size == 7 else 1 68 | 69 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False) 70 | self.sigmoid = nn.Sigmoid() 71 | 72 | def forward(self, x): 73 | max_out, _ = torch.max(x, dim=1, keepdim=True) 74 | x=max_out 75 | x = self.conv1(x) 76 | return self.sigmoid(x) 77 | 78 | class BasicConv2d(nn.Module): 79 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 80 | super(BasicConv2d, self).__init__() 81 | self.conv = nn.Conv2d(in_planes, out_planes, 82 | kernel_size=kernel_size, stride=stride, 83 | padding=padding, dilation=dilation, bias=False) 84 | self.bn = nn.BatchNorm2d(out_planes) 85 | self.relu = nn.ReLU(inplace=True) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | x = self.bn(x) 90 | return x 91 | 92 | #Global Contextual module 93 | class GCM(nn.Module): 94 | def __init__(self, in_channel, out_channel): 95 | super(GCM, self).__init__() 96 | self.relu = nn.ReLU(True) 97 | self.branch0 = nn.Sequential( 98 | BasicConv2d(in_channel, out_channel, 1), 99 | ) 100 | self.branch1 = nn.Sequential( 101 | BasicConv2d(in_channel, out_channel, 1), 102 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)), 103 | BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)), 104 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3) 105 | ) 106 | self.branch2 = nn.Sequential( 107 | BasicConv2d(in_channel, out_channel, 1), 108 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)), 109 | BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)), 110 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5) 111 | ) 112 | self.branch3 = nn.Sequential( 113 | BasicConv2d(in_channel, out_channel, 1), 114 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)), 115 | BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)), 116 | BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7) 117 | ) 118 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1) 119 | self.conv_res = BasicConv2d(in_channel, out_channel, 1) 120 | 121 | def forward(self, x): 122 | x0 = self.branch0(x) 123 | x1 = self.branch1(x) 124 | x2 = self.branch2(x) 125 | x3 = self.branch3(x) 126 | 127 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1)) 128 | 129 | x = self.relu(x_cat + self.conv_res(x)) 130 | return x 131 | 132 | #aggregation of the high-level(teacher) features 133 | class aggregation_init(nn.Module): 134 | 135 | def __init__(self, channel): 136 | super(aggregation_init, self).__init__() 137 | self.relu = nn.ReLU(True) 138 | 139 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 140 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1) 141 | self.conv_upsample2 = 
BasicConv2d(channel, channel, 3, padding=1) 142 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1) 143 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1) 144 | self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 145 | 146 | self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 147 | self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1) 148 | self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1) 149 | self.conv5 = nn.Conv2d(3*channel, 1, 1) 150 | 151 | def forward(self, x1, x2, x3): 152 | x1_1 = x1 153 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2 154 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \ 155 | * self.conv_upsample3(self.upsample(x2)) * x3 156 | 157 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1) 158 | x2_2 = self.conv_concat2(x2_2) 159 | 160 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1) 161 | x3_2 = self.conv_concat3(x3_2) 162 | 163 | x = self.conv4(x3_2) 164 | x = self.conv5(x) 165 | 166 | return x 167 | 168 | #aggregation of the low-level(student) features 169 | class aggregation_final(nn.Module): 170 | 171 | def __init__(self, channel): 172 | super(aggregation_final, self).__init__() 173 | self.relu = nn.ReLU(True) 174 | 175 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 176 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1) 177 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1) 178 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1) 179 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1) 180 | self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 181 | 182 | self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 183 | self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1) 184 | 185 | def forward(self, x1, x2, x3): 186 | x1_1 = x1 187 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2 188 | x3_1 = self.conv_upsample2(self.upsample(x1)) \ 189 | * self.conv_upsample3(x2) * x3 190 | 191 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1) 192 | x2_2 = self.conv_concat2(x2_2) 193 | 194 | x3_2 = torch.cat((x3_1, self.conv_upsample5(x2_2)), 1) 195 | x3_2 = self.conv_concat3(x3_2) 196 | 197 | return x3_2 198 | 199 | #Refinement flow 200 | class Refine(nn.Module): 201 | def __init__(self): 202 | super(Refine,self).__init__() 203 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 204 | 205 | def forward(self, attention,x1,x2,x3): 206 | #Note that there is an error in the manuscript. In the paper, the refinement strategy is depicted as ""f'=f*S1"", it should be ""f'=f+f*S1"". 
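        #In other words, for each low-level feature f_i and the initial saliency map S1:
        #    f_i' = f_i + f_i * S1   (S1 is upsampled x2 for x1 and x2, and used at its own scale for x3),
        #so the refinement is residual: features are enhanced by the attention map rather than replaced by it.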
207 | x1 = x1+torch.mul(x1, self.upsample2(attention)) 208 | x2 = x2+torch.mul(x2,self.upsample2(attention)) 209 | x3 = x3+torch.mul(x3,attention) 210 | 211 | return x1,x2,x3 212 | 213 | #BBSNet 214 | class BBSNet(nn.Module): 215 | def __init__(self, channel=32): 216 | super(BBSNet, self).__init__() 217 | 218 | #Backbone model 219 | self.resnet = ResNet50('rgb') 220 | self.resnet_depth=ResNet50('rgbd') 221 | 222 | #Decoder 1 223 | self.rfb2_1 = GCM(512, channel) 224 | self.rfb3_1 = GCM(1024, channel) 225 | self.rfb4_1 = GCM(2048, channel) 226 | self.agg1 = aggregation_init(channel) 227 | 228 | #Decoder 2 229 | self.rfb0_2 = GCM(64, channel) 230 | self.rfb1_2 = GCM(256, channel) 231 | self.rfb5_2 = GCM(512, channel) 232 | self.agg2 = aggregation_final(channel) 233 | 234 | #upsample function 235 | self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 236 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 237 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 238 | 239 | #Refinement flow 240 | self.HA = Refine() 241 | 242 | #Components of DEM module 243 | self.atten_depth_channel_0=ChannelAttention(64) 244 | self.atten_depth_channel_1=ChannelAttention(256) 245 | self.atten_depth_channel_2=ChannelAttention(512) 246 | self.atten_depth_channel_3_1=ChannelAttention(1024) 247 | self.atten_depth_channel_4_1=ChannelAttention(2048) 248 | 249 | self.atten_depth_spatial_0=SpatialAttention() 250 | self.atten_depth_spatial_1=SpatialAttention() 251 | self.atten_depth_spatial_2=SpatialAttention() 252 | self.atten_depth_spatial_3_1=SpatialAttention() 253 | self.atten_depth_spatial_4_1=SpatialAttention() 254 | 255 | #Components of PTM module 256 | self.inplanes = 32*2 257 | self.deconv1 = self._make_transpose(TransBasicBlock, 32*2, 3, stride=2) 258 | self.inplanes =32 259 | self.deconv2 = self._make_transpose(TransBasicBlock, 32, 3, stride=2) 260 | self.agant1 = self._make_agant_layer(32*3, 32*2) 261 | self.agant2 = self._make_agant_layer(32*2, 32) 262 | self.out0_conv = nn.Conv2d(32*3, 1, kernel_size=1, stride=1, bias=True) 263 | self.out1_conv = nn.Conv2d(32*2, 1, kernel_size=1, stride=1, bias=True) 264 | self.out2_conv = nn.Conv2d(32*1, 1, kernel_size=1, stride=1, bias=True) 265 | 266 | if self.training: 267 | self.initialize_weights() 268 | 269 | def forward(self, x, x_depth): 270 | x = self.resnet.conv1(x) 271 | x = self.resnet.bn1(x) 272 | x = self.resnet.relu(x) 273 | x = self.resnet.maxpool(x) 274 | 275 | x_depth = self.resnet_depth.conv1(x_depth) 276 | x_depth = self.resnet_depth.bn1(x_depth) 277 | x_depth = self.resnet_depth.relu(x_depth) 278 | x_depth = self.resnet_depth.maxpool(x_depth) 279 | 280 | #layer0 merge 281 | temp = x_depth.mul(self.atten_depth_channel_0(x_depth)) 282 | temp = temp.mul(self.atten_depth_spatial_0(temp)) 283 | x=x+temp 284 | #layer0 merge end 285 | 286 | x1 = self.resnet.layer1(x) # 256 x 64 x 64 287 | x1_depth=self.resnet_depth.layer1(x_depth) 288 | 289 | #layer1 merge 290 | temp = x1_depth.mul(self.atten_depth_channel_1(x1_depth)) 291 | temp = temp.mul(self.atten_depth_spatial_1(temp)) 292 | x1=x1+temp 293 | #layer1 merge end 294 | 295 | x2 = self.resnet.layer2(x1) # 512 x 32 x 32 296 | x2_depth=self.resnet_depth.layer2(x1_depth) 297 | 298 | #layer2 merge 299 | temp = x2_depth.mul(self.atten_depth_channel_2(x2_depth)) 300 | temp = temp.mul(self.atten_depth_spatial_2(temp)) 301 | x2=x2+temp 302 | #layer2 merge end 303 | 304 | x2_1 = x2 305 | 306 | x3_1 = self.resnet.layer3_1(x2_1) 
# 1024 x 16 x 16 307 | x3_1_depth=self.resnet_depth.layer3_1(x2_depth) 308 | 309 | #layer3_1 merge 310 | temp = x3_1_depth.mul(self.atten_depth_channel_3_1(x3_1_depth)) 311 | temp = temp.mul(self.atten_depth_spatial_3_1(temp)) 312 | x3_1=x3_1+temp 313 | #layer3_1 merge end 314 | 315 | x4_1 = self.resnet.layer4_1(x3_1) # 2048 x 8 x 8 316 | x4_1_depth=self.resnet_depth.layer4_1(x3_1_depth) 317 | 318 | #layer4_1 merge 319 | temp = x4_1_depth.mul(self.atten_depth_channel_4_1(x4_1_depth)) 320 | temp = temp.mul(self.atten_depth_spatial_4_1(temp)) 321 | x4_1=x4_1+temp 322 | #layer4_1 merge end 323 | 324 | #produce initial saliency map by decoder1 325 | x2_1 = self.rfb2_1(x2_1) 326 | x3_1 = self.rfb3_1(x3_1) 327 | x4_1 = self.rfb4_1(x4_1) 328 | attention_map = self.agg1(x4_1, x3_1, x2_1) 329 | 330 | #Refine low-layer features by initial map 331 | x,x1,x5 = self.HA(attention_map.sigmoid(), x,x1,x2) 332 | 333 | #produce final saliency map by decoder2 334 | x0_2 = self.rfb0_2(x) 335 | x1_2 = self.rfb1_2(x1) 336 | x5_2 = self.rfb5_2(x5) 337 | y = self.agg2(x5_2, x1_2, x0_2) #*4 338 | 339 | #PTM module 340 | y =self.agant1(y) 341 | y = self.deconv1(y) 342 | y = self.agant2(y) 343 | y = self.deconv2(y) 344 | y = self.out2_conv(y) 345 | 346 | return self.upsample(attention_map),y 347 | 348 | def _make_agant_layer(self, inplanes, planes): 349 | layers = nn.Sequential( 350 | nn.Conv2d(inplanes, planes, kernel_size=1, 351 | stride=1, padding=0, bias=False), 352 | nn.BatchNorm2d(planes), 353 | nn.ReLU(inplace=True) 354 | ) 355 | return layers 356 | 357 | def _make_transpose(self, block, planes, blocks, stride=1): 358 | upsample = None 359 | if stride != 1: 360 | upsample = nn.Sequential( 361 | nn.ConvTranspose2d(self.inplanes, planes, 362 | kernel_size=2, stride=stride, 363 | padding=0, bias=False), 364 | nn.BatchNorm2d(planes), 365 | ) 366 | elif self.inplanes != planes: 367 | upsample = nn.Sequential( 368 | nn.Conv2d(self.inplanes, planes, 369 | kernel_size=1, stride=stride, bias=False), 370 | nn.BatchNorm2d(planes), 371 | ) 372 | 373 | layers = [] 374 | 375 | for i in range(1, blocks): 376 | layers.append(block(self.inplanes, self.inplanes)) 377 | 378 | layers.append(block(self.inplanes, planes, stride, upsample)) 379 | self.inplanes = planes 380 | 381 | return nn.Sequential(*layers) 382 | 383 | #initialize the weights 384 | def initialize_weights(self): 385 | res50 = models.resnet50(pretrained=True) 386 | pretrained_dict = res50.state_dict() 387 | all_params = {} 388 | for k, v in self.resnet.state_dict().items(): 389 | if k in pretrained_dict.keys(): 390 | v = pretrained_dict[k] 391 | all_params[k] = v 392 | elif '_1' in k: 393 | name = k.split('_1')[0] + k.split('_1')[1] 394 | v = pretrained_dict[name] 395 | all_params[k] = v 396 | elif '_2' in k: 397 | name = k.split('_2')[0] + k.split('_2')[1] 398 | v = pretrained_dict[name] 399 | all_params[k] = v 400 | assert len(all_params.keys()) == len(self.resnet.state_dict().keys()) 401 | self.resnet.load_state_dict(all_params) 402 | 403 | all_params = {} 404 | for k, v in self.resnet_depth.state_dict().items(): 405 | if k=='conv1.weight': 406 | all_params[k]=torch.nn.init.normal_(v, mean=0, std=1) 407 | elif k in pretrained_dict.keys(): 408 | v = pretrained_dict[k] 409 | all_params[k] = v 410 | elif '_1' in k: 411 | name = k.split('_1')[0] + k.split('_1')[1] 412 | v = pretrained_dict[name] 413 | all_params[k] = v 414 | elif '_2' in k: 415 | name = k.split('_2')[0] + k.split('_2')[1] 416 | v = pretrained_dict[name] 417 | all_params[k] = v 418 | assert 
len(all_params.keys()) == len(self.resnet_depth.state_dict().keys()) 419 | self.resnet_depth.load_state_dict(all_params) 420 | 421 | -------------------------------------------------------------------------------- /models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 51 | padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * 4) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv3(out) 71 | out = self.bn3(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class ResNet50(nn.Module): 83 | def __init__(self,mode='rgb'): 84 | self.inplanes = 64 85 | super(ResNet50, self).__init__() 86 | if(mode=='rgb'): 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 88 | bias=False) 89 | elif(mode=='rgbd'): 90 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 91 | bias=False) 92 | elif(mode=="share"): 93 | self.conv1=nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 94 | bias=False) 95 | self.conv1_d=nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 96 | bias=False) 97 | else: 98 | raise 99 | self.bn1 = nn.BatchNorm2d(64) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 102 | self.layer1 = self._make_layer(Bottleneck, 64, 3) 103 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2) 104 | self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2) 105 | self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2) 106 | 107 | self.inplanes = 512 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | n = m.kernel_size[0] * m.kernel_size[1] * 
m.out_channels 112 | m.weight.data.normal_(0, math.sqrt(2. / n)) 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | 117 | def _make_layer(self, block, planes, blocks, stride=1): 118 | downsample = None 119 | if stride != 1 or self.inplanes != planes * block.expansion: 120 | downsample = nn.Sequential( 121 | nn.Conv2d(self.inplanes, planes * block.expansion, 122 | kernel_size=1, stride=stride, bias=False), 123 | nn.BatchNorm2d(planes * block.expansion), 124 | ) 125 | 126 | layers = [] 127 | layers.append(block(self.inplanes, planes, stride, downsample)) 128 | self.inplanes = planes * block.expansion 129 | for i in range(1, blocks): 130 | layers.append(block(self.inplanes, planes)) 131 | 132 | return nn.Sequential(*layers) 133 | 134 | def forward(self, x): 135 | x = self.conv1(x) 136 | x = self.bn1(x) 137 | x = self.relu(x) 138 | x = self.maxpool(x) 139 | 140 | x = self.layer1(x) 141 | x = self.layer2(x) 142 | x1 = self.layer3_1(x) 143 | x1 = self.layer4_1(x1) 144 | 145 | return x1, x1 146 | -------------------------------------------------------------------------------- /models/__pycache__/BBSNet_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyjwuyan/BBS-Net/d48bd4f658844d78ca5b39d36bae65ce50688a00/models/__pycache__/BBSNet_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyjwuyan/BBS-Net/d48bd4f658844d78ca5b39d36bae65ce50688a00/models/__pycache__/ResNet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/ResNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zyjwuyan/BBS-Net/d48bd4f658844d78ca5b39d36bae65ce50688a00/models/__pycache__/ResNet.cpython-37.pyc -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | parser.add_argument('--epoch', type=int, default=200, help='epoch number') 4 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 5 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size') 6 | parser.add_argument('--trainsize', type=int, default=352, help='training dataset size') 7 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 8 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 9 | parser.add_argument('--decay_epoch', type=int, default=60, help='every n epochs decay learning rate') 10 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints') 11 | parser.add_argument('--gpu_id', type=str, default='0', help='train use gpu') 12 | parser.add_argument('--rgb_root', type=str, default='../BBS_dataset/RGBD_for_train/RGB/', help='the training rgb images root') 13 | parser.add_argument('--depth_root', type=str, default='../BBS_dataset/RGBD_for_train/depth/', help='the training depth images root') 14 | parser.add_argument('--gt_root', type=str, default='../BBS_dataset/RGBD_for_train/GT/', help='the training gt images root') 15 | 
parser.add_argument('--test_rgb_root', type=str, default='../BBS_dataset/test_in_train/RGB/', help='the test rgb images root') 16 | parser.add_argument('--test_depth_root', type=str, default='../BBS_dataset/test_in_train/depth/', help='the test depth images root') 17 | parser.add_argument('--test_gt_root', type=str, default='../BBS_dataset/test_in_train/GT/', help='the test gt images root') 18 | parser.add_argument('--save_path', type=str, default='./BBSNet_cpts/', help='the path to save models and logs') 19 | opt = parser.parse_args() 20 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def clip_gradient(optimizer, grad_clip): 2 | for group in optimizer.param_groups: 3 | for param in group['params']: 4 | if param.grad is not None: 5 | param.grad.data.clamp_(-grad_clip, grad_clip) 6 | 7 | 8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 9 | decay = decay_rate ** (epoch // decay_epoch) 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = decay*init_lr 12 | lr=param_group['lr'] 13 | return lr 14 | --------------------------------------------------------------------------------
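
As a quick sanity check on the training configuration, the following standalone sketch prints the step learning-rate schedule that `utils.adjust_lr` produces with the defaults from `options.py`; the epoch values are arbitrary sample points.

```python
# lr = init_lr * decay_rate ** (epoch // decay_epoch), applied to every param group.
# With lr=1e-4, decay_rate=0.1, decay_epoch=60:
#   epochs   0-59  -> 1e-4
#   epochs  60-119 -> 1e-5
#   epochs 120-179 -> 1e-6
init_lr, decay_rate, decay_epoch = 1e-4, 0.1, 60
for epoch in (0, 59, 60, 119, 120, 179):
    print(epoch, init_lr * decay_rate ** (epoch // decay_epoch))
```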