├── .DS_Store
├── LICENSE
├── README.md
├── dataset_loader.py
├── demo.py
├── figure
│   ├── dataset.png
│   └── overall.png
├── functions.py
├── fusion.py
├── model.py
├── trainer.py
└── transform.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DUT-IIAU-OIP-Lab/DMRA_RGBD-SOD/ca9dc57e50c425293c1c131bbee25abd10f62b89/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 wei ji
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DMRA_RGBD-SOD
2 | Code repository for our paper entitled "Depth-induced Multi-scale Recurrent Attention Network for Saliency Detection", accepted at ICCV 2019 (poster).
3 | 
4 | # Overall
5 | ![avatar](https://github.com/jiwei0921/DMRA/blob/master/figure/overall.png)
6 | 
7 | ## The proposed Dataset
8 | + Dataset: DUTLF-V1
9 | 1. This dataset consists of DUTLF-MV, DUTLF-FS, and DUTLF-Depth.
10 | 2. The dataset will be expanded to about 3000 real scenes.
11 | 3. We are working on it and will make it publicly available soon.
12 | + Dataset: DUTLF-Depth
13 | 1. This dataset is part of the DUTLF dataset captured by a Lytro camera; we selected 1200 pairs with more accurate depth maps for RGB-D saliency detection.
14 | 2. We construct a large-scale RGB-D dataset (DUTLF-Depth) with 1200 paired images containing more complex scenarios, such as multiple or transparent objects, similar foreground and background, complex backgrounds, and low-intensity environments. This challenging dataset can contribute to comprehensively evaluating saliency models.
15 | 
16 | ![avatar](https://github.com/jiwei0921/DMRA/blob/master/figure/dataset.png)
17 | + The **dataset link** can be found [here](https://pan.baidu.com/s/1FwUFmNBox_gMZ0CVjby2dg). The dataset is split into an 800-image training set and a 400-image test set.
18 | 
19 | ## DMRA Code
20 | 
21 | ### > Requirements
22 | + pytorch 0.3.0+
23 | + torchvision
24 | + PIL
25 | + numpy
26 | 
27 | ### > Usage
28 | #### 1. Clone the repo
29 | ```
30 | git clone https://github.com/jiwei0921/DMRA.git
31 | cd DMRA/
32 | ```
33 | #### 2. Train/Test
34 | + test
35 | Download the related test datasets from this [**link**](https://github.com/jiwei0921/RGBD-SOD-datasets), then set '--phase' to "**test**" and '--param' to '**True**' in ```demo.py```. Make sure the **dataset path** and **checkpoint name** are set correctly (the expected directory layout is sketched at the end of this README).
36 | ```
37 | python demo.py
38 | ```
39 | + train
40 | Download our augmented training dataset from this [**link**](https://pan.baidu.com/s/18nVAiOkTKczB_ZpIzBHA0A) [ fetch code **haxl** ] or the original [train-ori dataset](https://pan.baidu.com/s/1B8PS4SXT7ISd-M6vAlrv_g), then set '--phase' to "**train**" and '--param' to '**True**' (load a checkpoint) or '**False**' (train from scratch) in ```demo.py```. Make sure the **dataset path** and **checkpoint name** are set correctly.
41 | ```
42 | python demo.py
43 | ```
44 | 
45 | ### > Training info and pre-trained models for DMRA
46 | To aid reproducibility, we retrained our network and report the training details below together with the corresponding pre-trained models. Iterations are given in units of 10,000 (W), so 100W corresponds to 1,000,000 iterations.
47 | 
48 | **Iterations** | **Loss** | NJUD(F-measure) | NJUD(MAE) | NLPR(F-measure) | NLPR(MAE) | download link
49 | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
50 | 100W | 958 | 0.882 | 0.048 | 0.867 | 0.031 | [link](https://pan.baidu.com/s/1Hb0CDDH7vG6F9yxl6wTymQ)
51 | 70W | 2413 | 0.876 | 0.050 | 0.854 | 0.033 | [link](https://pan.baidu.com/s/19SvkoKrkLPHFJUa_9z4ulg)
52 | 40W | 3194 | 0.861 | 0.056 | 0.823 | 0.037 | [link](https://pan.baidu.com/s/1_1ihh0TIm9pwQ4nyNSXKDQ)
53 | 16W | 8260 | 0.805 | 0.081 | 0.725 | 0.056 | [link](https://pan.baidu.com/s/1BzCOBV5HKNLAJcON0ImqyQ)
54 | 2W | 33494 | 0.009 | 0.470 | 0.030 | 0.452 | [link](https://pan.baidu.com/s/1QUJsr3oPOCUJsJu8nCHbHQ)
55 | 0W | 45394 | - | - | - | - | -
56 | 
57 | + Tip: **the results reported in the paper take precedence.** Because of randomness in the training process, the numbers may fluctuate slightly.
58 | 
59 | 
60 | ### > Results
61 | | [DUTLF-Depth](https://pan.baidu.com/s/1mS9EzoyY_ULXb3BCSd21eA) |
62 | | [NJUD](https://pan.baidu.com/s/1smz7KQbCPPClw58bDheH4w) |
63 | | [NLPR](https://pan.baidu.com/s/19qJkHtFQGV9oVtEFWY_ctg) |
64 | | [STEREO](https://pan.baidu.com/s/1L11R1c51mMPTrfpW6ykGjA) |
65 | | [LFSD](https://pan.baidu.com/s/1asgu1fGsHRk4CZcbz0NYxA) |
66 | | [RGBD135](https://pan.baidu.com/s/1jRYgoAijf_digGLQnsSbhA) |
67 | | [SSD](https://pan.baidu.com/s/1VY4I-4qpWS3wewz0MC8kqA) |
68 | + Note: all results are evaluated with this ready-to-use [toolbox](https://github.com/jiwei0921/Saliency-Evaluation-Toolbox).
69 | 
70 | ### > Related RGB-D Saliency Datasets
71 | All the common RGB-D saliency datasets we collected are shared in a ready-to-use manner.
72 | + The web link is [here](https://github.com/jiwei0921/RGBD-SOD-datasets).
73 | 
74 | ### If you think this work is helpful, please cite
75 | ```
76 | @InProceedings{Piao_2019_ICCV,
77 | author = {Yongri {Piao} and Wei {Ji} and Jingjing {Li} and Miao {Zhang} and Huchuan {Lu}},
78 | title = {Depth-induced Multi-scale Recurrent Attention Network for Saliency Detection},
79 | booktitle = "ICCV",
80 | year = {2019}
81 | }
82 | ```
83 | 
84 | ### Contact Us
85 | If you have any questions, please contact us ( jiwei521@mail.dlut.edu.cn ).
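### > Expected data layout
The folder and checkpoint names below are the ones ```dataset_loader.py``` and ```demo.py``` actually look for; the root paths are placeholders for your own '--train_dataroot' / '--test_dataroot'.
```
<train_dataroot>/
    train_images/   # RGB images (*.jpg)
    train_masks/    # ground-truth saliency masks (*.png)
    train_depth/    # depth maps (*.png)
<test_dataroot>/
    test_images/    # RGB images (*.jpg)
    test_depth/     # depth maps (*.png)
```
For testing, ```demo.py``` (with its defaults '--phase'="test" and '--param'=True) loads 'snapshot_iter_1000000.pth', 'depth_snapshot_iter_1000000.pth' and 'clstm_snapshot_iter_1000000.pth' from '--snapshot_root' (default './snapshot') and writes saliency maps to '--salmap_root' (default './sal_map'), e.g.
```
python demo.py --test_dataroot <path/to/test_data>
```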
86 | -------------------------------------------------------------------------------- /dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import scipy.io as sio 6 | import torch 7 | from torch.utils import data 8 | 9 | class MyData(data.Dataset): # inherit 10 | """ 11 | load data in a folder 12 | """ 13 | mean_rgb = np.array([0.447, 0.407, 0.386]) 14 | std_rgb = np.array([0.244, 0.250, 0.253]) 15 | def __init__(self, root, transform=False): 16 | super(MyData, self).__init__() 17 | self.root = root 18 | 19 | self._transform = transform 20 | 21 | img_root = os.path.join(self.root, 'train_images') 22 | lbl_root = os.path.join(self.root, 'train_masks') 23 | depth_root = os.path.join(self.root, 'train_depth') 24 | 25 | file_names = os.listdir(img_root) 26 | self.img_names = [] 27 | self.lbl_names = [] 28 | self.depth_names = [] 29 | for i, name in enumerate(file_names): 30 | if not name.endswith('.jpg'): 31 | continue 32 | self.lbl_names.append( 33 | os.path.join(lbl_root, name[:-4]+'.png') 34 | ) 35 | self.img_names.append( 36 | os.path.join(img_root, name) 37 | ) 38 | self.depth_names.append( 39 | os.path.join(depth_root, name[:-4]+'.png') 40 | ) 41 | 42 | def __len__(self): 43 | return len(self.img_names) 44 | 45 | def __getitem__(self, index): 46 | # load image 47 | img_file = self.img_names[index] 48 | img = PIL.Image.open(img_file) 49 | # img = img.resize((256, 256)) 50 | img = np.array(img, dtype=np.uint8) 51 | # load label 52 | lbl_file = self.lbl_names[index] 53 | lbl = PIL.Image.open(lbl_file) 54 | # lbl = lbl.resize((256, 256)) 55 | lbl = np.array(lbl, dtype=np.int32) 56 | lbl[lbl != 0] = 1 57 | # load depth 58 | depth_file = self.depth_names[index] 59 | depth = PIL.Image.open(depth_file) 60 | # depth = depth.resize(256, 256) 61 | depth = np.array(depth, dtype=np.uint8) 62 | 63 | 64 | 65 | if self._transform: 66 | return self.transform(img, lbl, depth) 67 | else: 68 | return img, lbl, depth 69 | 70 | 71 | # Translating numpy_array into format that pytorch can use on Code. 
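    # Specifically: the RGB image is scaled to [0, 1], normalized with the channel-wise
    # mean_rgb / std_rgb statistics defined above, and reordered from HWC to CHW;
    # the binary mask becomes a LongTensor and the depth map a FloatTensor scaled to [0, 1].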
72 | def transform(self, img, lbl, depth): 73 | 74 | img = img.astype(np.float64)/255.0 75 | img -= self.mean_rgb 76 | img /= self.std_rgb 77 | img = img.transpose(2, 0, 1) # to verify 78 | img = torch.from_numpy(img).float() 79 | lbl = torch.from_numpy(lbl).long() 80 | depth = depth.astype(np.float64)/255.0 81 | depth = torch.from_numpy(depth).float() 82 | return img, lbl, depth 83 | 84 | 85 | class MyTestData(data.Dataset): 86 | """ 87 | load data in a folder 88 | """ 89 | mean_rgb = np.array([0.447, 0.407, 0.386]) 90 | std_rgb = np.array([0.244, 0.250, 0.253]) 91 | 92 | 93 | def __init__(self, root, transform=False): 94 | super(MyTestData, self).__init__() 95 | self.root = root 96 | self._transform = transform 97 | 98 | img_root = os.path.join(self.root, 'test_images') 99 | depth_root = os.path.join(self.root, 'test_depth') 100 | file_names = os.listdir(img_root) 101 | self.img_names = [] 102 | self.names = [] 103 | self.depth_names = [] 104 | 105 | for i, name in enumerate(file_names): 106 | if not name.endswith('.jpg'): 107 | continue 108 | self.img_names.append( 109 | os.path.join(img_root, name) 110 | ) 111 | self.names.append(name[:-4]) 112 | self.depth_names.append( 113 | # os.path.join(depth_root, name[:-4]+'_depth.png') # Test RGBD135 dataset 114 | os.path.join(depth_root, name[:-4] + '.png') 115 | ) 116 | 117 | def __len__(self): 118 | return len(self.img_names) 119 | 120 | def __getitem__(self, index): 121 | # load image 122 | img_file = self.img_names[index] 123 | img = PIL.Image.open(img_file) 124 | img_size = img.size 125 | # img = img.resize((256, 256)) 126 | img = np.array(img, dtype=np.uint8) 127 | 128 | # load focal 129 | depth_file = self.depth_names[index] 130 | depth = PIL.Image.open(depth_file) 131 | # depth = depth.resize(256, 256) 132 | depth = np.array(depth, dtype=np.uint8) 133 | if self._transform: 134 | img, focal = self.transform(img, depth) 135 | return img, focal, self.names[index], img_size 136 | else: 137 | return img, depth, self.names[index], img_size 138 | 139 | def transform(self, img, depth): 140 | img = img.astype(np.float64)/255.0 141 | img -= self.mean_rgb 142 | img /= self.std_rgb 143 | img = img.transpose(2, 0, 1) 144 | img = torch.from_numpy(img).float() 145 | 146 | depth = depth.astype(np.float64)/255.0 147 | depth = torch.from_numpy(depth).float() 148 | 149 | return img, depth 150 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ Title: Depth-induced Multi-scale Recurrent Attention Network for Saliency Detection Author: Wei Ji, Jingjing Li E-mail: weiji.dlut@gmail.com """ import torch from torch.autograd import Variable from torch.utils.data import DataLoader import torchvision import torch.nn.functional as F import torch.optim as optim from dataset_loader import MyData, MyTestData from model import RGBNet,DepthNet from fusion import ConvLSTM from functions import imsave import argparse from trainer import Trainer import os configurations = { # same configuration as original work # https://github.com/shelhamer/fcn.berkeleyvision.org 1: dict( max_iteration=1000000, lr=1.0e-10, momentum=0.99, weight_decay=0.0005, spshot=20000, nclass=2, sshow=10, ) } parser=argparse.ArgumentParser() parser.add_argument('--phase', type=str, default='test', help='train or test') parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters') # parser.add_argument('--train_dataroot', type=str, 
default='/home/jiwei-computer/Documents/Depth_data/train_data', help='path to train data') parser.add_argument('--train_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/train_data-augment', help='path to train data') parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/DUT-RGBD/test_data', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/NJUD/test_data', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/NLPR/test_data', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/LFSD', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/SSD', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/STEREO', help='path to test data') # parser.add_argument('--test_dataroot', type=str, default='/home/jiwei-computer/Documents/Depth_data/RGBD135', help='path to test data') # Need to set dataset_loader.py/line 113 parser.add_argument('--snapshot_root', type=str, default='./snapshot', help='path to snapshot') parser.add_argument('--salmap_root', type=str, default='./sal_map', help='path to saliency map') parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys()) args = parser.parse_args() cfg = configurations[args.config] cuda = torch.cuda.is_available() """""""""""~~~ dataset loader ~~~""""""""" train_dataRoot = args.train_dataroot test_dataRoot = args.test_dataroot if not os.path.exists(args.snapshot_root): os.mkdir(args.snapshot_root) if not os.path.exists(args.salmap_root): os.mkdir(args.salmap_root) if args.phase == 'train': SnapRoot = args.snapshot_root # checkpoint train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True), batch_size=2, shuffle=True, num_workers=4, pin_memory=True) else: MapRoot = args.salmap_root test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True), batch_size=1, shuffle=True, num_workers=4, pin_memory=True) print ('data already') """"""""""" ~~~nets~~~ """"""""" start_epoch = 0 start_iteration = 0 model_rgb = RGBNet(cfg['nclass']) model_depth = DepthNet(cfg['nclass']) model_clstm = ConvLSTM(input_channels=64, hidden_channels=[64, 32, 64], kernel_size=5, step=4, effective_step=[2, 4, 8]) if args.param is True: model_rgb.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'snapshot_iter_1000000.pth'))) model_depth.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'depth_snapshot_iter_1000000.pth'))) model_clstm.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'clstm_snapshot_iter_1000000.pth'))) else: vgg19_bn = torchvision.models.vgg19_bn(pretrained=True) model_rgb.copy_params_from_vgg19_bn(vgg19_bn) model_depth.copy_params_from_vgg19_bn(vgg19_bn) if cuda: model_rgb = model_rgb.cuda() model_depth = model_depth.cuda() model_clstm = model_clstm.cuda() if args.phase == 'train': # Trainer: class, defined in trainer.py optimizer_rgb = optim.SGD(model_rgb.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) optimizer_depth = optim.SGD(model_depth.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) optimizer_clstm = optim.SGD(model_clstm.parameters(), 
lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) training = Trainer( cuda=cuda, model_rgb=model_rgb, model_depth=model_depth, model_clstm=model_clstm, optimizer_rgb=optimizer_rgb, optimizer_depth=optimizer_depth, optimizer_clstm=optimizer_clstm, train_loader=train_loader, max_iter=cfg['max_iteration'], snapshot=cfg['spshot'], outpath=args.snapshot_root, sshow=cfg['sshow'] ) training.epoch = start_epoch training.iteration = start_iteration training.train() else: for id, (data, depth, img_name, img_size) in enumerate(test_loader): print('testing bach %d' % (id+1)) inputs = Variable(data).cuda() inputs_depth = Variable(depth).cuda() n, c, h, w = inputs.size() depth = inputs_depth.view(n, h, w, 1).repeat(1, 1, 1, c) depth = depth.transpose(3, 1) depth = depth.transpose(3, 2) h1, h2, h3, h4, h5 = model_rgb(inputs) # RGBNet's output depth_vector, d1, d2, d3, d4, d5 = model_depth(depth) # DepthNet's output outputs_all = model_clstm(depth_vector, h1, h2, h3, h4, h5, d1, d2, d3, d4, d5) # Final output outputs_all = F.softmax(outputs_all, dim=1) outputs = outputs_all[0][1] outputs = outputs.cpu().data.resize_(h, w) imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size) print('The testing process has finished!') -------------------------------------------------------------------------------- /figure/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUT-IIAU-OIP-Lab/DMRA_RGBD-SOD/ca9dc57e50c425293c1c131bbee25abd10f62b89/figure/dataset.png -------------------------------------------------------------------------------- /figure/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DUT-IIAU-OIP-Lab/DMRA_RGBD-SOD/ca9dc57e50c425293c1c131bbee25abd10f62b89/figure/overall.png -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from scipy.misc import imresize 5 | 6 | def imsave(file_name, img, img_size): 7 | """ 8 | save a torch tensor as an image 9 | :param file_name: 'image/folder/image_name' 10 | :param img: 3*h*w torch tensor 11 | :return: nothing 12 | """ 13 | assert(type(img) == torch.FloatTensor, 14 | 'img must be a torch.FloatTensor') 15 | ndim = len(img.size()) 16 | assert(ndim == 2 or ndim == 3, 17 | 'img must be a 2 or 3 dimensional tensor') 18 | 19 | img = img.numpy() 20 | img = imresize(img, [img_size[1][0], img_size[0][0]], interp='nearest') 21 | if ndim == 3: 22 | plt.imsave(file_name, np.transpose(img, (1, 2, 0))) 23 | else: 24 | plt.imsave(file_name, img, cmap='gray') 25 | -------------------------------------------------------------------------------- /fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | ''' 7 | fusion: consits of DRB, DMSW, RAM. 
8 | ''' 9 | 10 | class ConvLSTMCell(nn.Module): 11 | def __init__(self, input_channels, hidden_channels, kernel_size, bias=True): 12 | super(ConvLSTMCell, self).__init__() 13 | 14 | assert hidden_channels % 2 == 0 15 | 16 | self.input_channels = input_channels 17 | self.hidden_channels = hidden_channels 18 | self.bias = bias 19 | self.kernel_size = kernel_size 20 | self.num_features = 4 21 | 22 | self.padding = (kernel_size - 1) //2 23 | self.conv = nn.Conv2d(self.input_channels + self.hidden_channels, 4 * self.hidden_channels, self.kernel_size, 1, 24 | self.padding) 25 | self._initialize_weights() 26 | 27 | def _initialize_weights(self): 28 | for m in self.modules(): 29 | if isinstance(m, nn.Conv2d): 30 | nn.init.normal(m.weight.data, std=0.01) 31 | if m.bias is not None: 32 | m.bias.data.zero_() 33 | 34 | def forward(self, input, h, c): 35 | 36 | combined = torch.cat((input, h), dim=1) 37 | A = self.conv(combined) 38 | (ai, af, ao, ag) = torch.split(A, A.size()[1] // self.num_features, dim=1) 39 | i = torch.sigmoid(ai) #input gate 40 | f = torch.sigmoid(af) #forget gate 41 | o = torch.sigmoid(ao) #output 42 | g = torch.tanh(ag) #update_Cell 43 | 44 | new_c = f * c + i * g 45 | new_h = o * torch.tanh(new_c) 46 | return new_h, new_c, o 47 | 48 | @staticmethod 49 | def init_hidden(batch_size, hidden_c, shape): 50 | return (Variable(torch.zeros(batch_size, hidden_c, shape[0], shape[1])).cuda(), 51 | Variable(torch.zeros(batch_size, hidden_c, shape[0], shape[1])).cuda()) 52 | 53 | 54 | class ConvLSTM(nn.Module): 55 | def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1], bias=True): 56 | super(ConvLSTM, self).__init__() 57 | self.input_channels = [input_channels] + hidden_channels 58 | self.hidden_channels = hidden_channels 59 | self.kernel_size = kernel_size 60 | self.num_layers = len(hidden_channels) 61 | self.step = step 62 | self.bias = bias 63 | self.effective_step = effective_step 64 | self._all_layers = [] 65 | for i in range(self.num_layers): 66 | name = 'cell{}'.format(i) 67 | cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size, self.bias) 68 | setattr(self, name, cell) 69 | self._all_layers.append(cell) 70 | 71 | 72 | 73 | # --------------------------- Depth Refinement Block -------------------------- # 74 | # DRB 1 75 | self.conv_refine1_1 = nn.Conv2d(64, 64, 3, padding=1) 76 | self.bn_refine1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 77 | self.relu_refine1_1 = nn.PReLU() 78 | self.conv_refine1_2 = nn.Conv2d(64, 64, 3, padding=1) 79 | self.bn_refine1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 80 | self.relu_refine1_2 = nn.PReLU() 81 | self.conv_refine1_3 = nn.Conv2d(64, 64, 3, padding=1) 82 | self.bn_refine1_3 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 83 | self.relu_refine1_3 = nn.PReLU() 84 | self.down_2_1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 85 | self.down_2_2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 86 | # DRB 2 87 | self.conv_refine2_1 = nn.Conv2d(128, 128, 3, padding=1) 88 | self.bn_refine2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 89 | self.relu_refine2_1 = nn.PReLU() 90 | self.conv_refine2_2 = nn.Conv2d(128, 128, 3, padding=1) 91 | self.bn_refine2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 92 | self.relu_refine2_2 = nn.PReLU() 93 | self.conv_refine2_3 = nn.Conv2d(128, 128, 3, padding=1) 94 | self.bn_refine2_3 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 95 | self.relu_refine2_3 = 
nn.PReLU() 96 | self.conv_r2_1 = nn.Conv2d(128, 64, 3, padding=1) 97 | self.bn_r2_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 98 | self.relu_r2_1 = nn.PReLU() 99 | # DRB 3 100 | self.conv_refine3_1 = nn.Conv2d(256, 256, 3, padding=1) 101 | self.bn_refine3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 102 | self.relu_refine3_1 = nn.PReLU() 103 | self.conv_refine3_2 = nn.Conv2d(256, 256, 3, padding=1) 104 | self.bn_refine3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 105 | self.relu_refine3_2 = nn.PReLU() 106 | self.conv_refine3_3 = nn.Conv2d(256, 256, 3, padding=1) 107 | self.bn_refine3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 108 | self.relu_refine3_3 = nn.PReLU() 109 | self.conv_r3_1 = nn.Conv2d(256, 64, 3, padding=1) 110 | self.bn_r3_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 111 | self.relu_r3_1 = nn.PReLU() 112 | # DRB 4 113 | self.conv_refine4_1 = nn.Conv2d(512, 512, 3, padding=1) 114 | self.bn_refine4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 115 | self.relu_refine4_1 = nn.PReLU() 116 | self.conv_refine4_2 = nn.Conv2d(512, 512, 3, padding=1) 117 | self.bn_refine4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 118 | self.relu_refine4_2 = nn.PReLU() 119 | self.conv_refine4_3 = nn.Conv2d(512, 512, 3, padding=1) 120 | self.bn_refine4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 121 | self.relu_refine4_3 = nn.PReLU() 122 | self.conv_r4_1 = nn.Conv2d(512, 64, 3, padding=1) 123 | self.bn_r4_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 124 | self.relu_r4_1 = nn.PReLU() 125 | # DRB 5 126 | self.conv_refine5_1 = nn.Conv2d(512, 512, 3, padding=1) 127 | self.bn_refine5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 128 | self.relu_refine5_1 = nn.PReLU() 129 | self.conv_refine5_2 = nn.Conv2d(512, 512, 3, padding=1) 130 | self.bn_refine5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 131 | self.relu_refine5_2 = nn.PReLU() 132 | self.conv_refine5_3 = nn.Conv2d(512, 512, 3, padding=1) 133 | self.bn_refine5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 134 | self.relu_refine5_3 = nn.PReLU() 135 | self.conv_r5_1 = nn.Conv2d(512, 64, 3, padding=1) 136 | self.bn_r5_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 137 | self.relu_r5_1 = nn.PReLU() 138 | 139 | 140 | # ----------------------------- Multi-scale ----------------------------- # 141 | # Add new structure: ASPP Atrous spatial Pyramid Pooling based on DeepLab v3 142 | # part0: 1*1*64 Conv 143 | self.conv5_conv_1 = nn.Conv2d(64, 64, 1, padding=0) # size: 64*64*64 144 | self.bn5_conv_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 145 | self.relu5_conv_1 = nn.ReLU(inplace=True) 146 | # part1: 3*3*64 Conv 147 | self.conv5_conv = nn.Conv2d(64, 64, 3, padding=1) # size: 64*64*64 148 | self.bn5_conv = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 149 | self.relu5_conv = nn.ReLU(inplace=True) 150 | # part2: 3*3*64 (dilated=7) Atrous Conv 151 | self.Atrous_conv_1 = nn.Conv2d(64, 64, 3, padding=7, dilation=7) # size: 64*64*64 152 | self.Atrous_bn5_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 153 | self.Atrous_relu_1 = nn.ReLU(inplace=True) 154 | # part3: 3*3*64 (dilated=5) Atrous Conv 155 | self.Atrous_conv_2 = nn.Conv2d(64, 64, 3, padding=5, dilation=5) # size: 64*64*64 156 | self.Atrous_bn5_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 157 | self.Atrous_relu_2 = nn.ReLU(inplace=True) 158 
| # part4: 3*3*64 (dilated=3) Atrous Conv 159 | self.Atrous_conv_5 = nn.Conv2d(64, 64, 3, padding=3, dilation=3) # size: 64*64*64 160 | self.Atrous_bn5_5 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 161 | self.Atrous_relu_5 = nn.ReLU(inplace=True) 162 | # part5: Max_pooling # size: 16*16*64 163 | self.Atrous_pooling = nn.MaxPool2d(2, stride=2, ceil_mode=True) 164 | self.Atrous_conv_pool = nn.Conv2d(64, 64, 1, padding=0) 165 | self.Atrous_bn_pool = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 166 | self.Atrous_relu_pool = nn.ReLU(inplace=True) 167 | 168 | 169 | 170 | # ----------------------------- Channel-wise Attention ----------------------------- # 171 | self.conv_c = nn.Conv2d(64, 64, 3, padding=1) 172 | self.conv_h = nn.Conv2d(64, 64, 3, padding=1) 173 | self.pool_avg = nn.AvgPool2d(64, stride=2, ceil_mode=True) # 1/8 174 | 175 | 176 | 177 | # ----------------------------- Sptatial-wise Attention ----------------------------- # 178 | self.conv_s1 = nn.Conv2d(64 * self.num_layers, 64, 1, padding=0) 179 | self.conv_s2 = nn.Conv2d(64 * self.num_layers, 1, 1, padding=0) 180 | 181 | 182 | # ----------------------------- Prediction ----------------------------- # 183 | self.conv_pred = nn.Conv2d(64, 2, 1, padding=0) 184 | 185 | self._initialize_weights() 186 | 187 | def _initialize_weights(self): 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.normal(m.weight.data, std=0.01) 191 | if m.bias is not None: 192 | m.bias.data.zero_() 193 | 194 | 195 | def forward(self,depth_vector,h1,h2,h3,h4,h5,d1,d2,d3,d4,d5): 196 | internal_state = [] 197 | 198 | 199 | # -------- apply DRB --------- # 200 | # drb 1 201 | d1_1 = self.relu_refine1_1(self.bn_refine1_1(self.conv_refine1_1(d1))) 202 | d1_2 = self.relu_refine1_2(self.bn_refine1_2(self.conv_refine1_2(d1_1))) 203 | d1_2 = d1_2 + h1 # (256x256)*64 204 | d1_2 = self.down_2_2(self.down_2_1(d1_2)) 205 | d1_2_0 = d1_2 206 | d1_3 = self.relu_refine1_3(self.bn_refine1_3(self.conv_refine1_3(d1_2))) 207 | drb1 = d1_2_0 + d1_3 # (64 x 64)*64 208 | 209 | # drb 2 210 | d2_1 = self.relu_refine2_1(self.bn_refine2_1(self.conv_refine2_1(d2))) 211 | d2_2 = self.relu_refine2_2(self.bn_refine2_2(self.conv_refine2_2(d2_1))) 212 | d2_2 = d2_2 + h2 # (128x128)*128 213 | d2_2 = self.down_2_1(d2_2) 214 | d2_2_0 = d2_2 215 | d2_3 = self.relu_refine2_3(self.bn_refine2_3(self.conv_refine2_3(d2_2))) 216 | drb2 = d2_2_0 + d2_3 217 | drb2 = self.relu_r2_1(self.bn_r2_1(self.conv_r2_1(drb2))) # (64 x 64)*64 218 | 219 | # drb 3 220 | d3_1 = self.relu_refine3_1(self.bn_refine3_1(self.conv_refine3_1(d3))) 221 | d3_2 = self.relu_refine3_2(self.bn_refine3_2(self.conv_refine3_2(d3_1))) 222 | d3_2 = d3_2 + h3 # (64 x 64)*256 223 | d3_2_0 = d3_2 224 | d3_3 = self.relu_refine3_3(self.bn_refine3_3(self.conv_refine3_3(d3_2))) 225 | drb3 = d3_2_0 + d3_3 226 | drb3 = self.relu_r3_1(self.bn_r3_1(self.conv_r3_1(drb3))) # (64 x 64)*64 227 | 228 | # drb 4 229 | d4_1 = self.relu_refine4_1(self.bn_refine4_1(self.conv_refine4_1(d4))) 230 | d4_2 = self.relu_refine4_2(self.bn_refine4_2(self.conv_refine4_2(d4_1))) 231 | d4_2 = d4_2 + h4 # (32 x 32)*512 232 | d4_2 = F.upsample(d4_2, scale_factor=2, mode='bilinear') 233 | d4_2_0 = d4_2 234 | d4_3 = self.relu_refine4_3(self.bn_refine4_3(self.conv_refine4_3(d4_2))) 235 | drb4 = d4_2_0 + d4_3 236 | drb4 = self.relu_r4_1(self.bn_r4_1(self.conv_r4_1(drb4))) # (64 x 64)*64 237 | 238 | # drb 5 239 | d5_1 = self.relu_refine5_1(self.bn_refine5_1(self.conv_refine5_1(d5))) 240 | d5_2 = 
self.relu_refine5_2(self.bn_refine5_2(self.conv_refine5_2(d5_1))) 241 | d5_2 = d5_2 + h5 # (16 x 16)*64 242 | d5_2 = F.upsample(d5_2, scale_factor=4, mode='bilinear') 243 | d5_2_0 = d5_2 244 | d5_3 = self.relu_refine5_3(self.bn_refine5_3(self.conv_refine5_3(d5_2))) 245 | drb5 = d5_2_0 + d5_3 246 | drb5 = self.relu_r5_1(self.bn_r5_1(self.conv_r5_1(drb5))) # (64 x 64)*64 247 | 248 | drb_fusion = drb1 +drb2 + drb3 +drb4 +drb5 # (64 x 64)*64 249 | 250 | 251 | # --------------------- obtain multi-scale ----------------------- # 252 | f1 = self.relu5_conv_1(self.bn5_conv_1(self.conv5_conv_1(drb_fusion))) 253 | f2 = self.relu5_conv(self.bn5_conv(self.conv5_conv(drb_fusion))) 254 | f3 = self.Atrous_relu_1(self.Atrous_bn5_1(self.Atrous_conv_1(drb_fusion))) 255 | f4 = self.Atrous_relu_2(self.Atrous_bn5_2(self.Atrous_conv_2(drb_fusion))) 256 | f5 = self.Atrous_relu_5(self.Atrous_bn5_5(self.Atrous_conv_5(drb_fusion))) 257 | f6 = F.upsample( 258 | self.Atrous_relu_pool(self.Atrous_bn_pool(self.Atrous_conv_pool(self.Atrous_pooling(self.Atrous_pooling(drb_fusion))))), 259 | scale_factor=4, mode='bilinear') 260 | 261 | 262 | 263 | 264 | fusion = torch.cat([f1,f2,f3,f4,f5,f6],dim=0) # 6x64x64x64 265 | fusion_o = fusion 266 | input = torch.cat(torch.chunk(fusion, 6, dim=0), dim=1) 267 | 268 | 269 | 270 | 271 | for step in range(self.step): 272 | depth = depth_vector # 1x 6 x 1 x1 273 | 274 | if step == 0: 275 | basize, _, height, width = input.size() 276 | (h_step, c) = ConvLSTMCell.init_hidden(basize, self.hidden_channels[self.num_layers-1],(height, width)) 277 | 278 | 279 | # Feature-wise Attention 280 | depth = torch.mul(F.softmax(depth,dim=1), 6) 281 | 282 | basize, dime, h, w = depth.size() 283 | 284 | depth = depth.view(1, basize, dime, h, w).transpose(0,1).transpose(1,2) 285 | depth = torch.cat(torch.chunk(depth, basize, dim=0), dim=1).view(basize*dime, 1, 1, 1) 286 | 287 | depth = torch.mul(fusion_o, depth).view(1, basize*dime, 64, 64, 64) 288 | depth = torch.cat(torch.chunk(depth, basize, dim=1), dim=0) 289 | F_sum = torch.sum(depth, 1, keepdim=False)#.squeeze() 290 | 291 | 292 | # Channel-wise Attention 293 | depth_fw_ori = F_sum 294 | depth = self.conv_c(F_sum) 295 | h_c = self.conv_h(h_step) 296 | depth = depth + h_c 297 | depth = self.pool_avg(depth) 298 | depth = torch.mul(F.softmax(depth, dim=1), 64) 299 | F_sum_wt = torch.mul(depth_fw_ori, depth) 300 | 301 | 302 | 303 | x = F_sum_wt 304 | if step < self.step-1: 305 | for i in range(self.num_layers): 306 | # all cells are initialized in the first step 307 | if step == 0: 308 | bsize, _, height, width = x.size() 309 | (h, c) = ConvLSTMCell.init_hidden(bsize, self.hidden_channels[i], (height, width)) 310 | internal_state.append((h, c)) 311 | # do forward 312 | name = 'cell{}'.format(i) 313 | (h, c) = internal_state[i] 314 | h_step = h 315 | 316 | x, new_c, new_o = getattr(self, name)(x, h, c) # ConvLSTMCell forward 317 | internal_state[i] = (x, new_c) 318 | 319 | # only record effective steps 320 | #if step in self.effective_step: 321 | 322 | if step == 0: 323 | outputs_o = new_o 324 | else: 325 | outputs_o = torch.cat((outputs_o, new_o), dim=1) 326 | 327 | # ---------------> Spatial-wise Attention Module <----------------- # 328 | outputs = self.conv_s1(outputs_o) 329 | spatial_weight = F.sigmoid(self.conv_s2(outputs_o)) 330 | outputs = torch.mul(outputs,spatial_weight) 331 | # -------------------------> Prediction <-------------------------- # 332 | outputs = self.conv_pred(outputs) 333 | output = F.upsample(outputs, scale_factor=4, 
mode='bilinear') 334 | 335 | return output 336 | 337 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 7 | """Make a 2D bilinear kernel suitable for upsampling""" 8 | factor = (kernel_size + 1) // 2 9 | if kernel_size % 2 == 1: 10 | center = factor - 1 11 | else: 12 | center = factor - 0.5 13 | og = np.ogrid[:kernel_size, :kernel_size] 14 | filt = (1 - abs(og[0] - center) / factor) * \ 15 | (1 - abs(og[1] - center) / factor) 16 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 17 | dtype=np.float64) 18 | weight[range(in_channels), range(out_channels), :, :] = filt 19 | return torch.from_numpy(weight).float() 20 | 21 | 22 | 23 | 24 | 25 | #################################### RGB Network ##################################### 26 | class RGBNet(nn.Module): 27 | def __init__(self,n_class=2): 28 | super(RGBNet, self).__init__() 29 | 30 | # original image's size = 256*256*3 31 | 32 | # conv1 33 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 34 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 35 | self.relu1_1 = nn.ReLU(inplace=True) 36 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 37 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 38 | self.relu1_2 = nn.ReLU(inplace=True) 39 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 40 | 41 | # conv2 42 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 43 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 44 | self.relu2_1 = nn.ReLU(inplace=True) 45 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 46 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 47 | self.relu2_2 = nn.ReLU(inplace=True) 48 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 49 | 50 | # conv3 51 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 52 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 53 | self.relu3_1 = nn.ReLU(inplace=True) 54 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 55 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 56 | self.relu3_2 = nn.ReLU(inplace=True) 57 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 58 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 59 | self.relu3_3 = nn.ReLU(inplace=True) 60 | self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) 61 | self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 62 | self.relu3_4 = nn.ReLU(inplace=True) 63 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers 64 | 65 | # conv4 66 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 67 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 68 | self.relu4_1 = nn.ReLU(inplace=True) 69 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 70 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 71 | self.relu4_2 = nn.ReLU(inplace=True) 72 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 73 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 74 | self.relu4_3 = nn.ReLU(inplace=True) 75 | self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) 76 | self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 77 | self.relu4_4 = nn.ReLU(inplace=True) 78 | self.pool4 = 
nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers 79 | 80 | # conv5 81 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 82 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 83 | self.relu5_1 = nn.ReLU(inplace=True) 84 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 85 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 86 | self.relu5_2 = nn.ReLU(inplace=True) 87 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 88 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 89 | self.relu5_3 = nn.ReLU(inplace=True) 90 | self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) 91 | self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 92 | self.relu5_4 = nn.ReLU(inplace=True) # 1/32 4 layers 93 | 94 | self._initialize_weights() 95 | 96 | 97 | def _initialize_weights(self): 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | # m.weight.data.zero_() 101 | nn.init.normal(m.weight.data, std=0.01) 102 | if m.bias is not None: 103 | m.bias.data.zero_() 104 | if isinstance(m, nn.ConvTranspose2d): 105 | assert m.kernel_size[0] == m.kernel_size[1] 106 | initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) 107 | m.weight.data.copy_(initial_weight) 108 | 109 | 110 | 111 | def forward(self, x): 112 | h = x 113 | 114 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) 115 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) 116 | h_nopool1 = h 117 | h = self.pool1(h) 118 | h1 = h_nopool1 # (256x256)*64 119 | 120 | h = self.relu2_1(self.bn2_1(self.conv2_1(h))) 121 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) 122 | h_nopool2 = h 123 | h = self.pool2(h) 124 | h2 = h_nopool2 # (128x128)*128 125 | 126 | h = self.relu3_1(self.bn3_1(self.conv3_1(h))) 127 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) 128 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) 129 | h = self.relu3_4(self.bn3_4(self.conv3_4(h))) 130 | h_nopool3 = h 131 | h = self.pool3(h) 132 | h3 = h_nopool3 # (64x64)*256 133 | 134 | h = self.relu4_1(self.bn4_1(self.conv4_1(h))) 135 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) 136 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) 137 | h = self.relu4_4(self.bn4_4(self.conv4_4(h))) 138 | h_nopool4 = h 139 | h = self.pool4(h) 140 | h4 = h_nopool4 # (32x32)*512 141 | 142 | h = self.relu5_1(self.bn5_1(self.conv5_1(h))) 143 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) 144 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 145 | h = self.relu5_4(self.bn5_4(self.conv5_4(h))) 146 | h5 = h # (16x16)*512 147 | 148 | 149 | return h1,h2,h3,h4,h5 150 | 151 | 152 | 153 | def copy_params_from_vgg19_bn(self, vgg19_bn): 154 | features = [ 155 | self.conv1_1, self.bn1_1, self.relu1_1, 156 | self.conv1_2, self.bn1_2, self.relu1_2, 157 | self.pool1, 158 | self.conv2_1, self.bn2_1, self.relu2_1, 159 | self.conv2_2, self.bn2_2, self.relu2_2, 160 | self.pool2, 161 | self.conv3_1, self.bn3_1, self.relu3_1, 162 | self.conv3_2, self.bn3_2, self.relu3_2, 163 | self.conv3_3, self.bn3_3, self.relu3_3, 164 | self.conv3_4, self.bn3_4, self.relu3_4, 165 | self.pool3, 166 | self.conv4_1, self.bn4_1, self.relu4_1, 167 | self.conv4_2, self.bn4_2, self.relu4_2, 168 | self.conv4_3, self.bn4_3, self.relu4_3, 169 | self.conv4_4, self.bn4_4, self.relu4_4, 170 | self.pool4, 171 | self.conv5_1, self.bn5_1, self.relu5_1, 172 | self.conv5_2, self.bn5_2, self.relu5_2, 173 | self.conv5_3, self.bn5_3, self.relu5_3, 174 | self.conv5_4, self.bn5_4, self.relu5_4, 175 | ] 176 | for l1, l2 in zip(vgg19_bn.features, features): 
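            # torchvision's vgg19_bn.features is a Sequential of Conv2d / BatchNorm2d / ReLU /
            # MaxPool2d modules in the same order as the `features` list above, so the pretrained
            # conv and BN parameters can be copied pairwise; any trailing layers are skipped by zip.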
177 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 178 | assert l1.weight.size() == l2.weight.size() 179 | assert l1.bias.size() == l2.bias.size() 180 | l2.weight.data = l1.weight.data 181 | l2.bias.data = l1.bias.data 182 | if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d): 183 | assert l1.weight.size() == l2.weight.size() 184 | assert l1.bias.size() == l2.bias.size() 185 | l2.weight.data = l1.weight.data 186 | l2.bias.data = l1.bias.data 187 | 188 | 189 | #################################### Depth Network ##################################### 190 | class DepthNet(nn.Module): 191 | def __init__(self, n_class=2): 192 | super(DepthNet, self).__init__() 193 | 194 | # original image's size = 256*256*3 195 | 196 | # conv1 197 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 198 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 199 | self.relu1_1 = nn.ReLU(inplace=True) 200 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 201 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 202 | self.relu1_2 = nn.ReLU(inplace=True) 203 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 204 | 205 | # conv2 206 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 207 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 208 | self.relu2_1 = nn.ReLU(inplace=True) 209 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 210 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 211 | self.relu2_2 = nn.ReLU(inplace=True) 212 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 213 | 214 | # conv3 215 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 216 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 217 | self.relu3_1 = nn.ReLU(inplace=True) 218 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 219 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 220 | self.relu3_2 = nn.ReLU(inplace=True) 221 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 222 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 223 | self.relu3_3 = nn.ReLU(inplace=True) 224 | self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) 225 | self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 226 | self.relu3_4 = nn.ReLU(inplace=True) 227 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers 228 | 229 | # conv4 230 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 231 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 232 | self.relu4_1 = nn.ReLU(inplace=True) 233 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 234 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 235 | self.relu4_2 = nn.ReLU(inplace=True) 236 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 237 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 238 | self.relu4_3 = nn.ReLU(inplace=True) 239 | self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) 240 | self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 241 | self.relu4_4 = nn.ReLU(inplace=True) 242 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers 243 | 244 | # conv5 245 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 246 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 247 | self.relu5_1 = nn.ReLU(inplace=True) 248 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 249 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 250 | 
self.relu5_2 = nn.ReLU(inplace=True) 251 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 252 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 253 | self.relu5_3 = nn.ReLU(inplace=True) 254 | self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) 255 | self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 256 | self.relu5_4 = nn.ReLU(inplace=True) # 1/32 4 layers 257 | 258 | # depth vector 259 | self.conv_fcn2 = nn.Conv2d(512, 64, 3, padding=1) 260 | self.pool_avg = nn.AvgPool2d(16, stride=2, ceil_mode=True) 261 | self.conv_c = nn.Conv2d(64, 6, 1, padding=0) 262 | 263 | self._initialize_weights() 264 | 265 | def _initialize_weights(self): 266 | for m in self.modules(): 267 | if isinstance(m, nn.Conv2d): 268 | # m.weight.data.zero_() 269 | nn.init.normal(m.weight.data, std=0.01) 270 | if m.bias is not None: 271 | m.bias.data.zero_() 272 | if isinstance(m, nn.ConvTranspose2d): 273 | assert m.kernel_size[0] == m.kernel_size[1] 274 | initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) 275 | m.weight.data.copy_(initial_weight) 276 | 277 | def forward(self, x): 278 | h = x 279 | 280 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) 281 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) 282 | h_nopool1 = h 283 | h = self.pool1(h) 284 | d1 = h_nopool1 # (256x256)*64 285 | 286 | h = self.relu2_1(self.bn2_1(self.conv2_1(h))) 287 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) 288 | h_nopool2 = h 289 | h = self.pool2(h) 290 | d2 = h_nopool2 # (128x128)*128 291 | 292 | h = self.relu3_1(self.bn3_1(self.conv3_1(h))) 293 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) 294 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) 295 | h = self.relu3_4(self.bn3_4(self.conv3_4(h))) 296 | h_nopool3 = h 297 | h = self.pool3(h) 298 | d3 = h_nopool3 # (64x64)*256 299 | 300 | h = self.relu4_1(self.bn4_1(self.conv4_1(h))) 301 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) 302 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) 303 | h = self.relu4_4(self.bn4_4(self.conv4_4(h))) 304 | h_nopool4 = h 305 | h = self.pool4(h) 306 | d4 = h_nopool4 # (32x32)*512 307 | 308 | h = self.relu5_1(self.bn5_1(self.conv5_1(h))) 309 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) 310 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 311 | h = self.relu5_4(self.bn5_4(self.conv5_4(h))) 312 | d5 = h # (16x16)*512 313 | 314 | # depth vector 315 | vector = self.conv_fcn2(d5) 316 | vector = self.pool_avg(vector) 317 | depth_vector = self.conv_c(vector) 318 | 319 | 320 | 321 | return depth_vector, d1, d2, d3, d4, d5 322 | 323 | def copy_params_from_vgg19_bn(self, vgg19_bn): 324 | features = [ 325 | self.conv1_1, self.bn1_1, self.relu1_1, 326 | self.conv1_2, self.bn1_2, self.relu1_2, 327 | self.pool1, 328 | self.conv2_1, self.bn2_1, self.relu2_1, 329 | self.conv2_2, self.bn2_2, self.relu2_2, 330 | self.pool2, 331 | self.conv3_1, self.bn3_1, self.relu3_1, 332 | self.conv3_2, self.bn3_2, self.relu3_2, 333 | self.conv3_3, self.bn3_3, self.relu3_3, 334 | self.conv3_4, self.bn3_4, self.relu3_4, 335 | self.pool3, 336 | self.conv4_1, self.bn4_1, self.relu4_1, 337 | self.conv4_2, self.bn4_2, self.relu4_2, 338 | self.conv4_3, self.bn4_3, self.relu4_3, 339 | self.conv4_4, self.bn4_4, self.relu4_4, 340 | self.pool4, 341 | self.conv5_1, self.bn5_1, self.relu5_1, 342 | self.conv5_2, self.bn5_2, self.relu5_2, 343 | self.conv5_3, self.bn5_3, self.relu5_3, 344 | self.conv5_4, self.bn5_4, self.relu5_4, 345 | ] 346 | for l1, l2 in zip(vgg19_bn.features, features): 347 | if isinstance(l1, nn.Conv2d) and 
isinstance(l2, nn.Conv2d): 348 | assert l1.weight.size() == l2.weight.size() 349 | assert l1.bias.size() == l2.bias.size() 350 | l2.weight.data = l1.weight.data 351 | l2.bias.data = l1.bias.data 352 | if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d): 353 | assert l1.weight.size() == l2.weight.size() 354 | assert l1.bias.size() == l2.bias.size() 355 | l2.weight.data = l1.weight.data 356 | l2.bias.data = l1.bias.data 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | 9 | running_loss_final = 0 10 | 11 | 12 | 13 | def cross_entropy2d(input, target, weight=None, size_average=True): 14 | n, c, h, w = input.size() 15 | 16 | input = input.transpose(1,2).transpose(2,3).contiguous() 17 | input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] 18 | input = input.view(-1, c) 19 | # target: (n*h*w,) 20 | mask = target >= 0 21 | target = target[mask] 22 | loss = F.cross_entropy(input, target, weight=weight, size_average=False) 23 | if size_average: 24 | loss /= mask.data.sum() 25 | return loss 26 | 27 | 28 | 29 | 30 | class Trainer(object): 31 | 32 | def __init__(self, cuda, model_rgb,model_depth,model_clstm, optimizer_rgb, 33 | optimizer_depth,optimizer_clstm, 34 | train_loader, max_iter, snapshot, outpath, sshow, size_average=False): 35 | self.cuda = cuda 36 | self.model_rgb = model_rgb 37 | self.model_depth = model_depth 38 | self.model_clstm = model_clstm 39 | self.optim_rgb = optimizer_rgb 40 | self.optim_depth = optimizer_depth 41 | self.optim_clstm = optimizer_clstm 42 | self.train_loader = train_loader 43 | self.epoch = 0 44 | self.iteration = 0 45 | self.max_iter = max_iter 46 | self.snapshot = snapshot 47 | self.outpath = outpath 48 | self.sshow = sshow 49 | self.size_average = size_average 50 | 51 | 52 | 53 | def train_epoch(self): 54 | 55 | for batch_idx, (data, target, depth) in enumerate(self.train_loader): 56 | 57 | 58 | iteration = batch_idx + self.epoch * len(self.train_loader) 59 | if self.iteration != 0 and (iteration - 1) != self.iteration: 60 | continue # for resuming 61 | self.iteration = iteration 62 | if self.iteration >= self.max_iter: 63 | break 64 | if self.cuda: 65 | data, target, depth = data.cuda(), target.cuda(), depth.cuda() 66 | data, target, depth = Variable(data), Variable(target), Variable(depth) 67 | n, c, h, w = data.size() # batch_size, channels, height, weight 68 | depth = depth.view(n,h,w,1).repeat(1,1,1,c) 69 | depth = depth.transpose(3,1) 70 | depth = depth.transpose(3,2) 71 | 72 | 73 | self.optim_rgb.zero_grad() 74 | self.optim_depth.zero_grad() 75 | self.optim_clstm.zero_grad() 76 | 77 | global running_loss_final 78 | 79 | 80 | h1,h2,h3,h4,h5 = self.model_rgb(data) # RGBNet's output 81 | depth_vector,d1,d2,d3,d4,d5 = self.model_depth(depth) # DepthNet's output 82 | 83 | # ------------------------------ Fusion --------------------------- # 84 | score_fusion = self.model_clstm(depth_vector,h1,h2,h3,h4,h5,d1,d2,d3,d4,d5) # Final output 85 | loss_all = cross_entropy2d(score_fusion, target, size_average=self.size_average) 86 | 87 | 88 | 89 | running_loss_final += loss_all.data[0] 90 | 91 | 92 | if iteration % self.sshow == (self.sshow-1): 93 | print('\n [%3d, %6d, The training loss of DMRA_Net: %.3f]' % (self.epoch + 1, iteration + 1, 
running_loss_final / (n * self.sshow))) 94 | 95 | running_loss_final = 0.0 96 | 97 | 98 | if iteration <= 200000: 99 | if iteration % self.snapshot == (self.snapshot-1): 100 | savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 101 | torch.save(self.model_rgb.state_dict(), savename) 102 | print('save: (snapshot: %d)' % (iteration+1)) 103 | 104 | savename_focal = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 105 | torch.save(self.model_depth.state_dict(), savename_focal) 106 | print('save: (snapshot_depth: %d)' % (iteration+1)) 107 | 108 | savename_clstm = ('%s/clstm_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 109 | torch.save(self.model_clstm.state_dict(), savename_clstm) 110 | print('save: (snapshot_clstm: %d)' % (iteration+1)) 111 | 112 | else: 113 | if iteration % 10000 == (10000 - 1): 114 | savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 115 | torch.save(self.model_rgb.state_dict(), savename) 116 | print('save: (snapshot: %d)' % (iteration + 1)) 117 | 118 | savename_focal = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 119 | torch.save(self.model_depth.state_dict(), savename_focal) 120 | print('save: (snapshot_depth: %d)' % (iteration + 1)) 121 | 122 | savename_clstm = ('%s/clstm_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 123 | torch.save(self.model_clstm.state_dict(), savename_clstm) 124 | print('save: (snapshot_clstm: %d)' % (iteration + 1)) 125 | 126 | 127 | 128 | if (iteration+1) == self.max_iter: 129 | savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 130 | torch.save(self.model_rgb.state_dict(), savename) 131 | print('save: (snapshot: %d)' % (iteration+1)) 132 | 133 | savename_focal = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 134 | torch.save(self.model_depth.state_dict(), savename_focal) 135 | print('save: (snapshot_depth: %d)' % (iteration+1)) 136 | 137 | savename_clstm = ('%s/clstm_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 138 | torch.save(self.model_clstm.state_dict(), savename_clstm) 139 | print('save: (snapshot_clstm: %d)' % (iteration+1)) 140 | 141 | 142 | 143 | 144 | loss_all.backward() 145 | self.optim_clstm.step() 146 | self.optim_depth.step() 147 | self.optim_rgb.step() 148 | 149 | def train(self): 150 | max_epoch = int(math.ceil(1. 
* self.max_iter / len(self.train_loader))) 151 | 152 | for epoch in range(max_epoch): 153 | self.epoch = epoch 154 | self.train_epoch() 155 | if self.iteration >= self.max_iter: 156 | break 157 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | 6 | def colormap(n): #import n, then r'g'b obtain values, finally acquiring colormap 7 | cmap=np.zeros([n, 3]).astype(np.uint8) 8 | 9 | for i in np.arange(n): 10 | r, g, b = np.zeros(3) 11 | 12 | for j in np.arange(8): 13 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) 14 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) 15 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) 16 | 17 | cmap[i,:] = np.array([r, g, b]) 18 | 19 | return cmap 20 | 21 | class Relabel: 22 | 23 | def __init__(self, olabel, nlabel): 24 | self.olabel = olabel 25 | self.nlabel = nlabel 26 | 27 | def __call__(self, tensor): 28 | assert isinstance(tensor, torch.LongTensor), 'tensor needs to be LongTensor' 29 | tensor[tensor == self.olabel] = self.nlabel 30 | return tensor 31 | 32 | 33 | class ToLabel: 34 | 35 | def __call__(self, image): 36 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) 37 | 38 | 39 | class Colorize: 40 | 41 | def __init__(self, n=21): 42 | self.cmap = colormap(256) 43 | self.cmap[n] = self.cmap[-1] 44 | self.cmap = torch.from_numpy(self.cmap[:n]) 45 | 46 | def __call__(self, gray_image): 47 | size = gray_image.size() 48 | color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) 49 | 50 | for label in range(1, len(self.cmap)): 51 | mask = (gray_image == label) 52 | 53 | color_image[0][mask] = self.cmap[label][0] 54 | color_image[1][mask] = self.cmap[label][1] 55 | color_image[2][mask] = self.cmap[label][2] 56 | 57 | return color_image 58 | --------------------------------------------------------------------------------
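For reference, a minimal usage sketch for the helpers in transform.py, which are not imported by the other scripts: colorizing a binary ground-truth mask. The file path and the two-class setting are assumptions, not part of the repository.
```python
# Hypothetical example: turn a {0, 1} mask PNG into a color image with Colorize.
import numpy as np
from PIL import Image

from transform import ToLabel, Colorize

mask_img = Image.open('train_masks/example.png').convert('L')   # hypothetical path
label = ToLabel()(mask_img).squeeze(0)      # H x W LongTensor of class indices
label[label != 0] = 1                       # binarize, as dataset_loader.py does
color = Colorize(n=2)(label)                # 3 x H x W ByteTensor (class 0 stays black)
Image.fromarray(np.transpose(color.numpy(), (1, 2, 0))).save('example_color.png')
```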