├── BBSNet_test.py
├── BBSNet_train.py
├── Images
│   ├── backbone_result.png
│   ├── detailed-comparisons.png
│   ├── pipeline.png
│   └── resultmap.png
├── LICENSE
├── README.md
├── data.py
├── models
│   ├── BBSNet_model.py
│   ├── ResNet.py
│   └── __pycache__
│       ├── BBSNet_model.cpython-37.pyc
│       ├── ResNet.cpython-36.pyc
│       └── ResNet.cpython-37.pyc
├── options.py
└── utils.py
/BBSNet_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import sys
4 | sys.path.append('./models')
5 | import numpy as np
6 | import os, argparse
7 | import cv2
8 | from models.BBSNet_model import BBSNet
9 | from data import test_dataset
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--testsize', type=int, default=352, help='testing size')
14 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
15 | parser.add_argument('--test_path',type=str,default='../BBS_dataset/RGBD_for_test/',help='test dataset path')
16 | opt = parser.parse_args()
17 |
18 | dataset_path = opt.test_path
19 |
20 | #set device for test
21 | if opt.gpu_id=='0':
22 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
23 | print('USE GPU 0')
24 | elif opt.gpu_id=='1':
25 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
26 | print('USE GPU 1')
27 |
28 | #load the model
29 | model = BBSNet()
30 | #Training for too many epochs may not generalize well. Choose a good checkpoint to load according to the log file and the .pth files saved in './BBSNet_cpts/' during training.
31 | model.load_state_dict(torch.load('./model_pths/BBSNet.pth'))
32 | model.cuda()
33 | model.eval()
34 |
35 | #test
36 | test_datasets = ['NJU2K','NLPR','STERE', 'DES', 'SSD','LFSD','SIP']
37 | for dataset in test_datasets:
38 | save_path = './test_maps/BBSNet/ResNet50/' + dataset + '/'
39 | if not os.path.exists(save_path):
40 | os.makedirs(save_path)
41 | image_root = dataset_path + dataset + '/RGB/'
42 | gt_root = dataset_path + dataset + '/GT/'
43 | depth_root=dataset_path +dataset +'/depth/'
44 | test_loader = test_dataset(image_root, gt_root,depth_root, opt.testsize)
45 | for i in range(test_loader.size):
46 | image, gt,depth, name, image_for_post = test_loader.load_data()
47 | gt = np.asarray(gt, np.float32)
48 | gt /= (gt.max() + 1e-8)
49 | image = image.cuda()
50 | depth = depth.cuda()
51 | _,res = model(image,depth)
52 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
53 | res = res.sigmoid().data.cpu().numpy().squeeze()
54 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
55 | print('save img to: ',save_path+name)
56 | cv2.imwrite(save_path+name,res*255)
57 | print('Test Done!')
58 |
--------------------------------------------------------------------------------
/BBSNet_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import sys
5 | sys.path.append('./models')
6 | import numpy as np
7 | from datetime import datetime
8 | from torchvision.utils import make_grid
9 | from models.BBSNet_model import BBSNet
10 | from data import get_loader,test_dataset
11 | from utils import clip_gradient, adjust_lr
12 | from tensorboardX import SummaryWriter
13 | import logging
14 | import torch.backends.cudnn as cudnn
15 | from options import opt
16 |
17 | #set the device for training
18 | if opt.gpu_id=='0':
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
20 | print('USE GPU 0')
21 | elif opt.gpu_id=='1':
22 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
23 | print('USE GPU 1')
24 | cudnn.benchmark = True
25 |
26 | #build the model
27 | model = BBSNet()
28 | if(opt.load is not None):
29 | model.load_state_dict(torch.load(opt.load))
30 | print('load model from ',opt.load)
31 |
32 | model.cuda()
33 | params = model.parameters()
34 | optimizer = torch.optim.Adam(params, opt.lr)
35 |
36 | #set the path
37 | image_root = opt.rgb_root
38 | gt_root = opt.gt_root
39 | depth_root=opt.depth_root
40 | test_image_root=opt.test_rgb_root
41 | test_gt_root=opt.test_gt_root
42 | test_depth_root=opt.test_depth_root
43 | save_path=opt.save_path
44 |
45 | if not os.path.exists(save_path):
46 | os.makedirs(save_path)
47 |
48 | #load data
49 | print('load data...')
50 | train_loader = get_loader(image_root, gt_root,depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
51 | test_loader = test_dataset(test_image_root, test_gt_root,test_depth_root, opt.trainsize)
52 | total_step = len(train_loader)
53 |
54 | logging.basicConfig(filename=save_path+'log.log',format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level = logging.INFO,filemode='a',datefmt='%Y-%m-%d %I:%M:%S %p')
55 | logging.info("BBSNet-Train")
56 | logging.info("Config")
57 | logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format(opt.epoch,opt.lr,opt.batchsize,opt.trainsize,opt.clip,opt.decay_rate,opt.load,save_path,opt.decay_epoch))
58 |
59 | #set loss function
60 | CE = torch.nn.BCEWithLogitsLoss()
61 |
62 | step=0
63 | writer = SummaryWriter(save_path+'summary')
64 | best_mae=1
65 | best_epoch=0
66 |
67 | #train function
68 | def train(train_loader, model, optimizer, epoch,save_path):
69 | global step
70 | model.train()
71 | loss_all=0
72 | epoch_step=0
73 | try:
74 | for i, (images, gts, depths) in enumerate(train_loader, start=1):
75 | optimizer.zero_grad()
76 |
77 | images = images.cuda()
78 | gts = gts.cuda()
79 | depths=depths.cuda()
80 |
81 | s1,s2 = model(images,depths)
82 | loss1 = CE(s1, gts)
83 | loss2 =CE(s2,gts)
84 | loss = loss1+loss2
85 | loss.backward()
86 |
87 | clip_gradient(optimizer, opt.clip)
88 | optimizer.step()
89 | step+=1
90 | epoch_step+=1
91 | loss_all+=loss.data
92 | if i % 100 == 0 or i == total_step or i==1:
93 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
94 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss1.data, loss2.data))
95 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss1: {:.4f} Loss2: {:0.4f}'.
96 | format( epoch, opt.epoch, i, total_step, loss1.data, loss2.data))
97 | writer.add_scalar('Loss', loss.data, global_step=step)
98 | grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True)
99 | writer.add_image('RGB', grid_image, step)
100 | grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True)
101 | writer.add_image('Ground_truth', grid_image, step)
102 | res=s1[0].clone()
103 | res = res.sigmoid().data.cpu().numpy().squeeze()
104 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
105 | writer.add_image('s1', torch.tensor(res), step,dataformats='HW')
106 | res=s2[0].clone()
107 | res = res.sigmoid().data.cpu().numpy().squeeze()
108 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
109 | writer.add_image('s2', torch.tensor(res), step,dataformats='HW')
110 |
111 | loss_all/=epoch_step
112 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format( epoch, opt.epoch, loss_all))
113 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)
114 | if (epoch) % 5 == 0:
115 | torch.save(model.state_dict(), save_path+'BBSNet_epoch_{}.pth'.format(epoch))
116 | except KeyboardInterrupt:
117 | print('Keyboard Interrupt: save model and exit.')
118 | if not os.path.exists(save_path):
119 | os.makedirs(save_path)
120 | torch.save(model.state_dict(), save_path+'BBSNet_epoch_{}.pth'.format(epoch+1))
121 | print('save checkpoints successfully!')
122 | raise
123 |
124 | #test function
125 | def test(test_loader,model,epoch,save_path):
126 | global best_mae,best_epoch
127 | model.eval()
128 | with torch.no_grad():
129 | mae_sum=0
130 | for i in range(test_loader.size):
131 | image, gt,depth, name,img_for_post = test_loader.load_data()
132 | gt = np.asarray(gt, np.float32)
133 | gt /= (gt.max() + 1e-8)
134 | image = image.cuda()
135 | depth = depth.cuda()
136 | _,res = model(image,depth)
137 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
138 | res = res.sigmoid().data.cpu().numpy().squeeze()
139 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
140 | mae_sum+=np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1])
141 | mae=mae_sum/test_loader.size
142 | writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch)
143 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch,mae,best_mae,best_epoch))
144 | if epoch==1:
145 | best_mae=mae
146 | else:
147 |             if mae<best_mae:
148 |                 best_mae=mae
149 |                 best_epoch=epoch
150 |                 torch.save(model.state_dict(), save_path+'BBSNet_epoch_best.pth')
151 |                 print('best epoch:{}'.format(epoch))
152 |         logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch,mae,best_epoch,best_mae))
153 | 
154 | if __name__ == '__main__':
155 |     print("Start train...")
156 |     for epoch in range(1, opt.epoch):
157 |         cur_lr = adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
158 |         writer.add_scalar('learning_rate', cur_lr, global_step=epoch)
159 |         train(train_loader, model, optimizer, epoch, save_path)
160 |         test(test_loader, model, epoch, save_path)
161 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network
6 |
7 |
8 | Figure 1: Pipeline of the BBS-Net.
9 |
10 |
11 |
12 | ## 1. Requirements
13 |
14 | Python 3.7, PyTorch 0.4.0+, CUDA 10.0, TensorboardX 2.0, opencv-python
15 |
16 | ## 2. Data Preparation
17 |
18 | - Download the raw data from [Baidu Pan](https://pan.baidu.com/s/1SxBjlTF4Tb74WjuDsRmM3w) [code: yiy1] or [Google Drive](https://drive.google.com/drive/folders/1gIMun9bM5JDrs98sLjXt7XoFCdvy1DXF?usp=sharing) and the trained model (BBSNet.pth) from [Here](https://pan.baidu.com/s/1Fn-Hvdou4DDWcgeTtx081g) [code: dwcp]. Then put them under the following directory structure:
19 |
   -BBS_dataset\
      -RGBD_for_train\
      -RGBD_for_test\
      -test_in_train\
   -BBSNet
      -models\
      -model_pths\
         -BBSNet.pth
      ...
29 |
30 | - Note that the depth maps of the raw data above are not normalized. If you train and test with normalized depth maps, the performance will be improved (a normalization sketch is given below).
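
A minimal sketch of one way to min-max normalize the depth maps before training and testing. This is not part of the repository: the directory names, the output location, and the choice to write 8-bit PNGs are assumptions, and depending on the dataset you may also need to invert the maps so that the polarity matches the convention noted in `data.py` (0 for background, 1 for foreground).

```python
import os

import cv2
import numpy as np


def normalize_depth_dir(src_dir, dst_dir):
    """Min-max normalize every depth map in src_dir and write it to dst_dir as 8-bit PNG."""
    os.makedirs(dst_dir, exist_ok=True)
    for name in os.listdir(src_dir):
        depth = cv2.imread(os.path.join(src_dir, name), cv2.IMREAD_GRAYSCALE)
        if depth is None:  # skip files that are not readable images
            continue
        depth = depth.astype(np.float32)
        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)  # rescale to [0, 1]
        out_name = os.path.splitext(name)[0] + '.png'
        cv2.imwrite(os.path.join(dst_dir, out_name), (depth * 255).astype(np.uint8))


# Example with hypothetical paths matching the layout above:
# normalize_depth_dir('../BBS_dataset/RGBD_for_train/depth/',
#                     '../BBS_dataset/RGBD_for_train/depth_normalized/')
```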
31 |
32 | ## 3. Training & Testing
33 |
34 | - Train the BBSNet:
35 |
36 | `python BBSNet_train.py --batchsize 10 --gpu_id 0 `
37 |
38 | - Test the BBSNet:
39 |
40 | `python BBSNet_test.py --gpu_id 0 `
41 |
42 | The test maps will be saved to './test_maps/'.
43 |
44 | - Evaluate the result maps:
45 |
46 | You can evaluate the result maps using the tool in [Python_GPU Version](https://github.com/zyjwuyan/SOD_Evaluation_Metrics) or [Matlab Version](http://dpfan.net/d3netbenchmark/). A minimal MAE sanity check is also sketched after this list.
47 |
48 | - If you need the code using VGG16 and VGG19 backbones, please send a request to the email (zhaiyingjier@163.com). Please provide your name and institution. Please note that the code can only be used for research purposes.
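
For a quick sanity check before running the full evaluation toolkits above, a per-dataset MAE can be computed directly from the saved maps. This is only a sketch with assumed paths and file naming (predictions saved as PNGs with the same stems as the ground-truth files), not the official evaluation code:

```python
import os

import cv2
import numpy as np


def mean_absolute_error(pred_dir, gt_dir):
    """Average per-image MAE between predicted saliency maps and ground-truth masks."""
    maes = []
    for name in sorted(os.listdir(gt_dir)):
        gt = cv2.imread(os.path.join(gt_dir, name), cv2.IMREAD_GRAYSCALE)
        pred_name = os.path.splitext(name)[0] + '.png'
        pred = cv2.imread(os.path.join(pred_dir, pred_name), cv2.IMREAD_GRAYSCALE)
        if gt is None or pred is None:  # skip missing or unreadable pairs
            continue
        gt = gt.astype(np.float32) / 255.0
        pred = pred.astype(np.float32) / 255.0
        if pred.shape != gt.shape:  # BBSNet_test.py already resizes predictions to the GT size
            pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]))
        maes.append(np.abs(pred - gt).mean())
    return float(np.mean(maes))


# Example with hypothetical paths:
# print(mean_absolute_error('./test_maps/BBSNet/ResNet50/NJU2K/',
#                           '../BBS_dataset/RGBD_for_test/NJU2K/GT/'))
```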
49 | ## 4. Results
50 | ### 4.1 Qualitative Comparison
51 |
52 |
53 |
54 | Figure 2: Qualitative visual comparison of the proposed model versus 8 SOTA
55 | models.
56 |
57 |
58 |
59 |
60 |
61 | Table 1: Quantitative comparison of models using S-measure, max F-measure, max E-measure, and MAE scores on 7 datasets.
62 |
63 |
64 |
72 |
73 | ### 4.2 Results of multiple backbones
74 |
75 |
76 |
77 |
78 | Table 2: Performance comparison using different backbones.
79 |
80 |
81 |
82 | ### 4.3 Download
83 | - Test maps of the above datasets (ResNet50 backbone) can be downloaded from [here](https://pan.baidu.com/s/1O-AhThLWEDVgQiPhX3QVYw) [code: qgai].
84 | - Test maps of the VGG16 and VGG19 backbones of our model can be downloaded from [here](https://pan.baidu.com/s/1_hG3hC2Fpt1cbAWuPrHEPA) [code: zuds].
85 | - Test maps of DUT-RGBD dataset (using the proposed training-test splits of [DMRA](https://openaccess.thecvf.com/content_ICCV_2019/papers/Piao_Depth-Induced_Multi-Scale_Recurrent_Attention_Network_for_Saliency_Detection_ICCV_2019_paper.pdf)) can be downloaded from [here](https://pan.baidu.com/s/15oc_-nwEKNiU1C9WRho5lg) [code: 3nme ].
86 | ## 5. Citation
87 |
88 | Please cite the following paper if you use this repository in your research.
89 |
90 | @inproceedings{fan2020bbsnet,
91 | title={BBS-Net: RGB-D Salient Object Detection with a Bifurcated Backbone Strategy Network},
92 | author={Fan, Deng-Ping and Zhai, Yingjie and Borji, Ali and Yang, Jufeng and Shao, Ling},
93 | booktitle={ECCV},
94 | year={2020}
95 | }
96 |
97 | - For more information about BBS-Net, please read the [Manuscript (PDF)](https://arxiv.org/pdf/2007.02713.pdf) ([Chinese version](https://pan.baidu.com/s/1zxni7QjBiewwA1Q-m7Cqfg)[code:0r4a]).
98 | - Note that there is an error in Fig. 3 (c) of the ECCV version. The second and third BConv3 in the first column of the figure should be BConv5 and BConv7, respectively.
99 |
100 | ## 6. Benchmark RGB-D SOD
101 |
102 | The complete RGB-D SOD benchmark can be found in this page:
103 |
104 | http://dpfan.net/d3netbenchmark/
105 |
106 | ## 7. Acknowledgement
107 | We implemented this project based on the code of "Cascaded Partial Decoder for Fast and Accurate Salient Object Detection" (CVPR 2019) by Wu et al.
108 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 | #several data augmentation strategies
10 | def cv_random_flip(img, label,depth):
11 | flip_flag = random.randint(0, 1)
12 | # flip_flag2= random.randint(0,1)
13 | #left right flip
14 | if flip_flag == 1:
15 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
16 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
17 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
18 | #top bottom flip
19 | # if flip_flag2==1:
20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
23 | return img, label, depth
24 | def randomCrop(image, label,depth):
25 | border=30
26 | image_width = image.size[0]
27 | image_height = image.size[1]
28 | crop_win_width = np.random.randint(image_width-border , image_width)
29 | crop_win_height = np.random.randint(image_height-border , image_height)
30 | random_region = (
31 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
32 | (image_height + crop_win_height) >> 1)
33 | return image.crop(random_region), label.crop(random_region),depth.crop(random_region)
34 | def randomRotation(image,label,depth):
35 | mode=Image.BICUBIC
36 | if random.random()>0.8:
37 | random_angle = np.random.randint(-15, 15)
38 | image=image.rotate(random_angle, mode)
39 | label=label.rotate(random_angle, mode)
40 | depth=depth.rotate(random_angle, mode)
41 | return image,label,depth
42 | def colorEnhance(image):
43 | bright_intensity=random.randint(5,15)/10.0
44 | image=ImageEnhance.Brightness(image).enhance(bright_intensity)
45 | contrast_intensity=random.randint(5,15)/10.0
46 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity)
47 | color_intensity=random.randint(0,20)/10.0
48 | image=ImageEnhance.Color(image).enhance(color_intensity)
49 | sharp_intensity=random.randint(0,30)/10.0
50 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity)
51 | return image
52 | def randomGaussian(image, mean=0.1, sigma=0.35):
53 | def gaussianNoisy(im, mean=mean, sigma=sigma):
54 | for _i in range(len(im)):
55 | im[_i] += random.gauss(mean, sigma)
56 | return im
57 | img = np.asarray(image)
58 | width, height = img.shape
59 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
60 | img = img.reshape([width, height])
61 | return Image.fromarray(np.uint8(img))
62 | def randomPeper(img):
63 |
64 | img=np.array(img)
65 | noiseNum=int(0.0015*img.shape[0]*img.shape[1])
66 | for i in range(noiseNum):
67 |
68 | randX=random.randint(0,img.shape[0]-1)
69 |
70 | randY=random.randint(0,img.shape[1]-1)
71 |
72 | if random.randint(0,1)==0:
73 |
74 | img[randX,randY]=0
75 |
76 | else:
77 |
78 | img[randX,randY]=255
79 | return Image.fromarray(img)
80 |
81 | # dataset for training
82 | #The current loader does not use normalized depth maps for training and testing. If you use normalized depth maps
83 | #(e.g., 0 represents background and 1 represents foreground), the performance will be further improved.
84 | class SalObjDataset(data.Dataset):
85 | def __init__(self, image_root, gt_root,depth_root, trainsize):
86 | self.trainsize = trainsize
87 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
88 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
89 | or f.endswith('.png')]
90 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
91 | or f.endswith('.png')]
92 | self.images = sorted(self.images)
93 | self.gts = sorted(self.gts)
94 | self.depths=sorted(self.depths)
95 | self.filter_files()
96 | self.size = len(self.images)
97 | self.img_transform = transforms.Compose([
98 | transforms.Resize((self.trainsize, self.trainsize)),
99 | transforms.ToTensor(),
100 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
101 | self.gt_transform = transforms.Compose([
102 | transforms.Resize((self.trainsize, self.trainsize)),
103 | transforms.ToTensor()])
104 | self.depths_transform = transforms.Compose([transforms.Resize((self.trainsize, self.trainsize)),transforms.ToTensor()])
105 |
106 | def __getitem__(self, index):
107 | image = self.rgb_loader(self.images[index])
108 | gt = self.binary_loader(self.gts[index])
109 | depth=self.binary_loader(self.depths[index])
110 | image,gt,depth =cv_random_flip(image,gt,depth)
111 | image,gt,depth=randomCrop(image, gt,depth)
112 | image,gt,depth=randomRotation(image, gt,depth)
113 | image=colorEnhance(image)
114 | # gt=randomGaussian(gt)
115 | gt=randomPeper(gt)
116 | image = self.img_transform(image)
117 | gt = self.gt_transform(gt)
118 | depth=self.depths_transform(depth)
119 |
120 | return image, gt, depth
121 |
122 | def filter_files(self):
123 |         assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
124 | images = []
125 | gts = []
126 | depths=[]
127 | for img_path, gt_path,depth_path in zip(self.images, self.gts, self.depths):
128 | img = Image.open(img_path)
129 | gt = Image.open(gt_path)
130 | depth= Image.open(depth_path)
131 | if img.size == gt.size and gt.size==depth.size:
132 | images.append(img_path)
133 | gts.append(gt_path)
134 | depths.append(depth_path)
135 | self.images = images
136 | self.gts = gts
137 | self.depths=depths
138 |
139 | def rgb_loader(self, path):
140 | with open(path, 'rb') as f:
141 | img = Image.open(f)
142 | return img.convert('RGB')
143 |
144 | def binary_loader(self, path):
145 | with open(path, 'rb') as f:
146 | img = Image.open(f)
147 | return img.convert('L')
148 |
149 | def resize(self, img, gt, depth):
150 | assert img.size == gt.size and gt.size==depth.size
151 | w, h = img.size
152 | if h < self.trainsize or w < self.trainsize:
153 | h = max(h, self.trainsize)
154 | w = max(w, self.trainsize)
155 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST),depth.resize((w, h), Image.NEAREST)
156 | else:
157 | return img, gt, depth
158 |
159 | def __len__(self):
160 | return self.size
161 |
162 | #dataloader for training
163 | def get_loader(image_root, gt_root,depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
164 |
165 | dataset = SalObjDataset(image_root, gt_root, depth_root,trainsize)
166 | data_loader = data.DataLoader(dataset=dataset,
167 | batch_size=batchsize,
168 | shuffle=shuffle,
169 | num_workers=num_workers,
170 | pin_memory=pin_memory)
171 | return data_loader
172 |
173 | #test dataset and loader
174 | class test_dataset:
175 | def __init__(self, image_root, gt_root,depth_root, testsize):
176 | self.testsize = testsize
177 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
178 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
179 | or f.endswith('.png')]
180 | self.depths=[depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
181 | or f.endswith('.png')]
182 | self.images = sorted(self.images)
183 | self.gts = sorted(self.gts)
184 | self.depths=sorted(self.depths)
185 | self.transform = transforms.Compose([
186 | transforms.Resize((self.testsize, self.testsize)),
187 | transforms.ToTensor(),
188 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
189 | self.gt_transform = transforms.ToTensor()
190 | # self.gt_transform = transforms.Compose([
191 | # transforms.Resize((self.trainsize, self.trainsize)),
192 | # transforms.ToTensor()])
193 | self.depths_transform = transforms.Compose([transforms.Resize((self.testsize, self.testsize)),transforms.ToTensor()])
194 | self.size = len(self.images)
195 | self.index = 0
196 |
197 | def load_data(self):
198 | image = self.rgb_loader(self.images[self.index])
199 | image = self.transform(image).unsqueeze(0)
200 | gt = self.binary_loader(self.gts[self.index])
201 | depth=self.binary_loader(self.depths[self.index])
202 | depth=self.depths_transform(depth).unsqueeze(0)
203 | name = self.images[self.index].split('/')[-1]
204 | image_for_post=self.rgb_loader(self.images[self.index])
205 | image_for_post=image_for_post.resize(gt.size)
206 | if name.endswith('.jpg'):
207 | name = name.split('.jpg')[0] + '.png'
208 | self.index += 1
209 | self.index = self.index % self.size
210 | return image, gt,depth, name,np.array(image_for_post)
211 |
212 | def rgb_loader(self, path):
213 | with open(path, 'rb') as f:
214 | img = Image.open(f)
215 | return img.convert('RGB')
216 |
217 | def binary_loader(self, path):
218 | with open(path, 'rb') as f:
219 | img = Image.open(f)
220 | return img.convert('L')
221 | def __len__(self):
222 | return self.size
223 |
224 |
--------------------------------------------------------------------------------
/models/BBSNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | from ResNet import ResNet50
5 | from torch.nn import functional as F
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | "3x3 convolution with padding"
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 | class TransBasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, inplanes, planes, stride=1, upsample=None, **kwargs):
15 | super(TransBasicBlock, self).__init__()
16 | self.conv1 = conv3x3(inplanes, inplanes)
17 | self.bn1 = nn.BatchNorm2d(inplanes)
18 | self.relu = nn.ReLU(inplace=True)
19 | if upsample is not None and stride != 1:
20 | self.conv2 = nn.ConvTranspose2d(inplanes, planes,
21 | kernel_size=3, stride=stride, padding=1,
22 | output_padding=1, bias=False)
23 | else:
24 | self.conv2 = conv3x3(inplanes, planes, stride)
25 | self.bn2 = nn.BatchNorm2d(planes)
26 | self.upsample = upsample
27 | self.stride = stride
28 |
29 | def forward(self, x):
30 | residual = x
31 |
32 | out = self.conv1(x)
33 | out = self.bn1(out)
34 | out = self.relu(out)
35 |
36 | out = self.conv2(out)
37 | out = self.bn2(out)
38 |
39 | if self.upsample is not None:
40 | residual = self.upsample(x)
41 |
42 | out += residual
43 | out = self.relu(out)
44 |
45 | return out
46 | class ChannelAttention(nn.Module):
47 | def __init__(self, in_planes, ratio=16):
48 | super(ChannelAttention, self).__init__()
49 |
50 | self.max_pool = nn.AdaptiveMaxPool2d(1)
51 |
52 |         self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
53 |         self.relu1 = nn.ReLU()
54 |         self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
55 |
56 | self.sigmoid = nn.Sigmoid()
57 | def forward(self, x):
58 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
59 | out = max_out
60 | return self.sigmoid(out)
61 |
62 | class SpatialAttention(nn.Module):
63 | def __init__(self, kernel_size=7):
64 | super(SpatialAttention, self).__init__()
65 |
66 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
67 | padding = 3 if kernel_size == 7 else 1
68 |
69 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
70 | self.sigmoid = nn.Sigmoid()
71 |
72 | def forward(self, x):
73 | max_out, _ = torch.max(x, dim=1, keepdim=True)
74 | x=max_out
75 | x = self.conv1(x)
76 | return self.sigmoid(x)
77 |
78 | class BasicConv2d(nn.Module):
79 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
80 | super(BasicConv2d, self).__init__()
81 | self.conv = nn.Conv2d(in_planes, out_planes,
82 | kernel_size=kernel_size, stride=stride,
83 | padding=padding, dilation=dilation, bias=False)
84 | self.bn = nn.BatchNorm2d(out_planes)
85 | self.relu = nn.ReLU(inplace=True)
86 |
87 | def forward(self, x):
88 | x = self.conv(x)
89 | x = self.bn(x)
90 | return x
91 |
92 | #Global Contextual module
93 | class GCM(nn.Module):
94 | def __init__(self, in_channel, out_channel):
95 | super(GCM, self).__init__()
96 | self.relu = nn.ReLU(True)
97 | self.branch0 = nn.Sequential(
98 | BasicConv2d(in_channel, out_channel, 1),
99 | )
100 | self.branch1 = nn.Sequential(
101 | BasicConv2d(in_channel, out_channel, 1),
102 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
103 | BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
104 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
105 | )
106 | self.branch2 = nn.Sequential(
107 | BasicConv2d(in_channel, out_channel, 1),
108 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
109 | BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
110 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
111 | )
112 | self.branch3 = nn.Sequential(
113 | BasicConv2d(in_channel, out_channel, 1),
114 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
115 | BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
116 | BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
117 | )
118 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
119 | self.conv_res = BasicConv2d(in_channel, out_channel, 1)
120 |
121 | def forward(self, x):
122 | x0 = self.branch0(x)
123 | x1 = self.branch1(x)
124 | x2 = self.branch2(x)
125 | x3 = self.branch3(x)
126 |
127 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))
128 |
129 | x = self.relu(x_cat + self.conv_res(x))
130 | return x
131 |
132 | #aggregation of the high-level(teacher) features
133 | class aggregation_init(nn.Module):
134 |
135 | def __init__(self, channel):
136 | super(aggregation_init, self).__init__()
137 | self.relu = nn.ReLU(True)
138 |
139 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
140 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
141 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
142 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
143 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
144 | self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
145 |
146 | self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
147 | self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
148 | self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
149 | self.conv5 = nn.Conv2d(3*channel, 1, 1)
150 |
151 | def forward(self, x1, x2, x3):
152 | x1_1 = x1
153 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
154 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
155 | * self.conv_upsample3(self.upsample(x2)) * x3
156 |
157 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
158 | x2_2 = self.conv_concat2(x2_2)
159 |
160 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
161 | x3_2 = self.conv_concat3(x3_2)
162 |
163 | x = self.conv4(x3_2)
164 | x = self.conv5(x)
165 |
166 | return x
167 |
168 | #aggregation of the low-level(student) features
169 | class aggregation_final(nn.Module):
170 |
171 | def __init__(self, channel):
172 | super(aggregation_final, self).__init__()
173 | self.relu = nn.ReLU(True)
174 |
175 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
176 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
177 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
178 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
179 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
180 | self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
181 |
182 | self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1)
183 | self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1)
184 |
185 | def forward(self, x1, x2, x3):
186 | x1_1 = x1
187 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
188 | x3_1 = self.conv_upsample2(self.upsample(x1)) \
189 | * self.conv_upsample3(x2) * x3
190 |
191 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
192 | x2_2 = self.conv_concat2(x2_2)
193 |
194 | x3_2 = torch.cat((x3_1, self.conv_upsample5(x2_2)), 1)
195 | x3_2 = self.conv_concat3(x3_2)
196 |
197 | return x3_2
198 |
199 | #Refinement flow
200 | class Refine(nn.Module):
201 | def __init__(self):
202 | super(Refine,self).__init__()
203 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
204 |
205 | def forward(self, attention,x1,x2,x3):
206 |         #Note that there is an error in the manuscript. In the paper, the refinement strategy is depicted as "f'=f*S1"; it should be "f'=f+f*S1".
207 | x1 = x1+torch.mul(x1, self.upsample2(attention))
208 | x2 = x2+torch.mul(x2,self.upsample2(attention))
209 | x3 = x3+torch.mul(x3,attention)
210 |
211 | return x1,x2,x3
212 |
213 | #BBSNet
214 | class BBSNet(nn.Module):
215 | def __init__(self, channel=32):
216 | super(BBSNet, self).__init__()
217 |
218 | #Backbone model
219 | self.resnet = ResNet50('rgb')
220 | self.resnet_depth=ResNet50('rgbd')
221 |
222 | #Decoder 1
223 | self.rfb2_1 = GCM(512, channel)
224 | self.rfb3_1 = GCM(1024, channel)
225 | self.rfb4_1 = GCM(2048, channel)
226 | self.agg1 = aggregation_init(channel)
227 |
228 | #Decoder 2
229 | self.rfb0_2 = GCM(64, channel)
230 | self.rfb1_2 = GCM(256, channel)
231 | self.rfb5_2 = GCM(512, channel)
232 | self.agg2 = aggregation_final(channel)
233 |
234 | #upsample function
235 | self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
236 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
237 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
238 |
239 | #Refinement flow
240 | self.HA = Refine()
241 |
242 | #Components of DEM module
243 | self.atten_depth_channel_0=ChannelAttention(64)
244 | self.atten_depth_channel_1=ChannelAttention(256)
245 | self.atten_depth_channel_2=ChannelAttention(512)
246 | self.atten_depth_channel_3_1=ChannelAttention(1024)
247 | self.atten_depth_channel_4_1=ChannelAttention(2048)
248 |
249 | self.atten_depth_spatial_0=SpatialAttention()
250 | self.atten_depth_spatial_1=SpatialAttention()
251 | self.atten_depth_spatial_2=SpatialAttention()
252 | self.atten_depth_spatial_3_1=SpatialAttention()
253 | self.atten_depth_spatial_4_1=SpatialAttention()
254 |
255 | #Components of PTM module
256 | self.inplanes = 32*2
257 | self.deconv1 = self._make_transpose(TransBasicBlock, 32*2, 3, stride=2)
258 | self.inplanes =32
259 | self.deconv2 = self._make_transpose(TransBasicBlock, 32, 3, stride=2)
260 | self.agant1 = self._make_agant_layer(32*3, 32*2)
261 | self.agant2 = self._make_agant_layer(32*2, 32)
262 | self.out0_conv = nn.Conv2d(32*3, 1, kernel_size=1, stride=1, bias=True)
263 | self.out1_conv = nn.Conv2d(32*2, 1, kernel_size=1, stride=1, bias=True)
264 | self.out2_conv = nn.Conv2d(32*1, 1, kernel_size=1, stride=1, bias=True)
265 |
266 | if self.training:
267 | self.initialize_weights()
268 |
269 | def forward(self, x, x_depth):
270 | x = self.resnet.conv1(x)
271 | x = self.resnet.bn1(x)
272 | x = self.resnet.relu(x)
273 | x = self.resnet.maxpool(x)
274 |
275 | x_depth = self.resnet_depth.conv1(x_depth)
276 | x_depth = self.resnet_depth.bn1(x_depth)
277 | x_depth = self.resnet_depth.relu(x_depth)
278 | x_depth = self.resnet_depth.maxpool(x_depth)
279 |
280 | #layer0 merge
281 | temp = x_depth.mul(self.atten_depth_channel_0(x_depth))
282 | temp = temp.mul(self.atten_depth_spatial_0(temp))
283 | x=x+temp
284 | #layer0 merge end
285 |
286 | x1 = self.resnet.layer1(x) # 256 x 64 x 64
287 | x1_depth=self.resnet_depth.layer1(x_depth)
288 |
289 | #layer1 merge
290 | temp = x1_depth.mul(self.atten_depth_channel_1(x1_depth))
291 | temp = temp.mul(self.atten_depth_spatial_1(temp))
292 | x1=x1+temp
293 | #layer1 merge end
294 |
295 | x2 = self.resnet.layer2(x1) # 512 x 32 x 32
296 | x2_depth=self.resnet_depth.layer2(x1_depth)
297 |
298 | #layer2 merge
299 | temp = x2_depth.mul(self.atten_depth_channel_2(x2_depth))
300 | temp = temp.mul(self.atten_depth_spatial_2(temp))
301 | x2=x2+temp
302 | #layer2 merge end
303 |
304 | x2_1 = x2
305 |
306 | x3_1 = self.resnet.layer3_1(x2_1) # 1024 x 16 x 16
307 | x3_1_depth=self.resnet_depth.layer3_1(x2_depth)
308 |
309 | #layer3_1 merge
310 | temp = x3_1_depth.mul(self.atten_depth_channel_3_1(x3_1_depth))
311 | temp = temp.mul(self.atten_depth_spatial_3_1(temp))
312 | x3_1=x3_1+temp
313 | #layer3_1 merge end
314 |
315 | x4_1 = self.resnet.layer4_1(x3_1) # 2048 x 8 x 8
316 | x4_1_depth=self.resnet_depth.layer4_1(x3_1_depth)
317 |
318 | #layer4_1 merge
319 | temp = x4_1_depth.mul(self.atten_depth_channel_4_1(x4_1_depth))
320 | temp = temp.mul(self.atten_depth_spatial_4_1(temp))
321 | x4_1=x4_1+temp
322 | #layer4_1 merge end
323 |
324 | #produce initial saliency map by decoder1
325 | x2_1 = self.rfb2_1(x2_1)
326 | x3_1 = self.rfb3_1(x3_1)
327 | x4_1 = self.rfb4_1(x4_1)
328 | attention_map = self.agg1(x4_1, x3_1, x2_1)
329 |
330 | #Refine low-layer features by initial map
331 | x,x1,x5 = self.HA(attention_map.sigmoid(), x,x1,x2)
332 |
333 | #produce final saliency map by decoder2
334 | x0_2 = self.rfb0_2(x)
335 | x1_2 = self.rfb1_2(x1)
336 | x5_2 = self.rfb5_2(x5)
337 | y = self.agg2(x5_2, x1_2, x0_2) #*4
338 |
339 | #PTM module
340 | y =self.agant1(y)
341 | y = self.deconv1(y)
342 | y = self.agant2(y)
343 | y = self.deconv2(y)
344 | y = self.out2_conv(y)
345 |
346 | return self.upsample(attention_map),y
347 |
348 | def _make_agant_layer(self, inplanes, planes):
349 | layers = nn.Sequential(
350 | nn.Conv2d(inplanes, planes, kernel_size=1,
351 | stride=1, padding=0, bias=False),
352 | nn.BatchNorm2d(planes),
353 | nn.ReLU(inplace=True)
354 | )
355 | return layers
356 |
357 | def _make_transpose(self, block, planes, blocks, stride=1):
358 | upsample = None
359 | if stride != 1:
360 | upsample = nn.Sequential(
361 | nn.ConvTranspose2d(self.inplanes, planes,
362 | kernel_size=2, stride=stride,
363 | padding=0, bias=False),
364 | nn.BatchNorm2d(planes),
365 | )
366 | elif self.inplanes != planes:
367 | upsample = nn.Sequential(
368 | nn.Conv2d(self.inplanes, planes,
369 | kernel_size=1, stride=stride, bias=False),
370 | nn.BatchNorm2d(planes),
371 | )
372 |
373 | layers = []
374 |
375 | for i in range(1, blocks):
376 | layers.append(block(self.inplanes, self.inplanes))
377 |
378 | layers.append(block(self.inplanes, planes, stride, upsample))
379 | self.inplanes = planes
380 |
381 | return nn.Sequential(*layers)
382 |
383 | #initialize the weights
384 | def initialize_weights(self):
385 | res50 = models.resnet50(pretrained=True)
386 | pretrained_dict = res50.state_dict()
387 | all_params = {}
388 | for k, v in self.resnet.state_dict().items():
389 | if k in pretrained_dict.keys():
390 | v = pretrained_dict[k]
391 | all_params[k] = v
392 | elif '_1' in k:
393 | name = k.split('_1')[0] + k.split('_1')[1]
394 | v = pretrained_dict[name]
395 | all_params[k] = v
396 | elif '_2' in k:
397 | name = k.split('_2')[0] + k.split('_2')[1]
398 | v = pretrained_dict[name]
399 | all_params[k] = v
400 | assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
401 | self.resnet.load_state_dict(all_params)
402 |
403 | all_params = {}
404 | for k, v in self.resnet_depth.state_dict().items():
405 | if k=='conv1.weight':
406 | all_params[k]=torch.nn.init.normal_(v, mean=0, std=1)
407 | elif k in pretrained_dict.keys():
408 | v = pretrained_dict[k]
409 | all_params[k] = v
410 | elif '_1' in k:
411 | name = k.split('_1')[0] + k.split('_1')[1]
412 | v = pretrained_dict[name]
413 | all_params[k] = v
414 | elif '_2' in k:
415 | name = k.split('_2')[0] + k.split('_2')[1]
416 | v = pretrained_dict[name]
417 | all_params[k] = v
418 | assert len(all_params.keys()) == len(self.resnet_depth.state_dict().keys())
419 | self.resnet_depth.load_state_dict(all_params)
420 |
421 |
--------------------------------------------------------------------------------
/models/ResNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """3x3 convolution with padding"""
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
8 | padding=1, bias=False)
9 |
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, inplanes, planes, stride=1, downsample=None):
15 | super(BasicBlock, self).__init__()
16 | self.conv1 = conv3x3(inplanes, planes, stride)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.conv2 = conv3x3(planes, planes)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.downsample = downsample
22 | self.stride = stride
23 |
24 | def forward(self, x):
25 | residual = x
26 |
27 | out = self.conv1(x)
28 | out = self.bn1(out)
29 | out = self.relu(out)
30 |
31 | out = self.conv2(out)
32 | out = self.bn2(out)
33 |
34 | if self.downsample is not None:
35 | residual = self.downsample(x)
36 |
37 | out += residual
38 | out = self.relu(out)
39 |
40 | return out
41 |
42 |
43 | class Bottleneck(nn.Module):
44 | expansion = 4
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None):
47 | super(Bottleneck, self).__init__()
48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
51 | padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(planes * 4)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 | out = self.relu(out)
69 |
70 | out = self.conv3(out)
71 | out = self.bn3(out)
72 |
73 | if self.downsample is not None:
74 | residual = self.downsample(x)
75 |
76 | out += residual
77 | out = self.relu(out)
78 |
79 | return out
80 |
81 |
82 | class ResNet50(nn.Module):
83 | def __init__(self,mode='rgb'):
84 | self.inplanes = 64
85 | super(ResNet50, self).__init__()
86 | if(mode=='rgb'):
87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
88 | bias=False)
89 | elif(mode=='rgbd'):
90 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
91 | bias=False)
92 | elif(mode=="share"):
93 | self.conv1=nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
94 | bias=False)
95 | self.conv1_d=nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
96 | bias=False)
97 | else:
98 |             raise ValueError('mode must be one of: rgb, rgbd, share')
99 | self.bn1 = nn.BatchNorm2d(64)
100 | self.relu = nn.ReLU(inplace=True)
101 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
102 | self.layer1 = self._make_layer(Bottleneck, 64, 3)
103 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
104 | self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2)
105 | self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2)
106 |
107 | self.inplanes = 512
108 |
109 | for m in self.modules():
110 | if isinstance(m, nn.Conv2d):
111 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
112 | m.weight.data.normal_(0, math.sqrt(2. / n))
113 | elif isinstance(m, nn.BatchNorm2d):
114 | m.weight.data.fill_(1)
115 | m.bias.data.zero_()
116 |
117 | def _make_layer(self, block, planes, blocks, stride=1):
118 | downsample = None
119 | if stride != 1 or self.inplanes != planes * block.expansion:
120 | downsample = nn.Sequential(
121 | nn.Conv2d(self.inplanes, planes * block.expansion,
122 | kernel_size=1, stride=stride, bias=False),
123 | nn.BatchNorm2d(planes * block.expansion),
124 | )
125 |
126 | layers = []
127 | layers.append(block(self.inplanes, planes, stride, downsample))
128 | self.inplanes = planes * block.expansion
129 | for i in range(1, blocks):
130 | layers.append(block(self.inplanes, planes))
131 |
132 | return nn.Sequential(*layers)
133 |
134 | def forward(self, x):
135 | x = self.conv1(x)
136 | x = self.bn1(x)
137 | x = self.relu(x)
138 | x = self.maxpool(x)
139 |
140 | x = self.layer1(x)
141 | x = self.layer2(x)
142 | x1 = self.layer3_1(x)
143 | x1 = self.layer4_1(x1)
144 |
145 | return x1, x1
146 |
--------------------------------------------------------------------------------
/models/__pycache__/BBSNet_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyjwuyan/BBS-Net/d48bd4f658844d78ca5b39d36bae65ce50688a00/models/__pycache__/BBSNet_model.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ResNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyjwuyan/BBS-Net/d48bd4f658844d78ca5b39d36bae65ce50688a00/models/__pycache__/ResNet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ResNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyjwuyan/BBS-Net/d48bd4f658844d78ca5b39d36bae65ce50688a00/models/__pycache__/ResNet.cpython-37.pyc
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | parser.add_argument('--epoch', type=int, default=200, help='epoch number')
4 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
5 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
6 | parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
7 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
8 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
9 | parser.add_argument('--decay_epoch', type=int, default=60, help='every n epochs decay learning rate')
10 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints')
11 | parser.add_argument('--gpu_id', type=str, default='0', help='train use gpu')
12 | parser.add_argument('--rgb_root', type=str, default='../BBS_dataset/RGBD_for_train/RGB/', help='the training rgb images root')
13 | parser.add_argument('--depth_root', type=str, default='../BBS_dataset/RGBD_for_train/depth/', help='the training depth images root')
14 | parser.add_argument('--gt_root', type=str, default='../BBS_dataset/RGBD_for_train/GT/', help='the training gt images root')
15 | parser.add_argument('--test_rgb_root', type=str, default='../BBS_dataset/test_in_train/RGB/', help='the test rgb images root')
16 | parser.add_argument('--test_depth_root', type=str, default='../BBS_dataset/test_in_train/depth/', help='the test depth images root')
17 | parser.add_argument('--test_gt_root', type=str, default='../BBS_dataset/test_in_train/GT/', help='the test gt images root')
18 | parser.add_argument('--save_path', type=str, default='./BBSNet_cpts/', help='the path to save models and logs')
19 | opt = parser.parse_args()
20 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | def clip_gradient(optimizer, grad_clip):
2 | for group in optimizer.param_groups:
3 | for param in group['params']:
4 | if param.grad is not None:
5 | param.grad.data.clamp_(-grad_clip, grad_clip)
6 |
7 |
8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
9 | decay = decay_rate ** (epoch // decay_epoch)
10 | for param_group in optimizer.param_groups:
11 | param_group['lr'] = decay*init_lr
12 | lr=param_group['lr']
13 | return lr
14 |
--------------------------------------------------------------------------------