├── ConvGRU.py
├── LICENSE
├── README.md
├── dataloader.py
├── eval.py
├── graph.py
└── model.py

--------------------------------------------------------------------------------
/ConvGRU.py:
--------------------------------------------------------------------------------

###################################################
# Nicolo Savioli, 2017 -- Conv-GRU pytorch v 1.0  #
###################################################
import torch
from torch import nn


class ConvGRUCell(nn.Module):
    """A convolutional GRU cell: the gates are computed with 2-D convolutions
    instead of the fully connected layers of a standard GRU, so the hidden
    state is a feature map rather than a vector."""

    def __init__(self, input_size, hidden_size, kernel_size):
        super(ConvGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        # 'same' padding keeps the spatial resolution of the hidden state
        self.padding = int((self.kernel_size - 1) / 2)
        # a single convolution produces both the reset and the update gate
        self.ConvGates = nn.Conv2d(self.input_size + self.hidden_size,
                                   2 * self.hidden_size,
                                   self.kernel_size, padding=self.padding)
        # convolution for the candidate hidden state
        self.Conv_ct = nn.Conv2d(self.input_size + self.hidden_size,
                                 self.hidden_size,
                                 self.kernel_size, padding=self.padding)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, input, hidden):
        if hidden is None:
            # lazily initialise the hidden state with zeros on the input's device
            size_h = [input.size(0), self.hidden_size] + list(input.size()[2:])
            hidden = torch.zeros(size_h, dtype=input.dtype, device=input.device)
        c1 = self.ConvGates(torch.cat((input, hidden), 1))
        (rt, ut) = c1.chunk(2, 1)
        reset_gate = torch.sigmoid(rt)
        update_gate = torch.sigmoid(ut)
        # candidate state from the input and the reset-gated hidden state
        gated_hidden = torch.mul(reset_gate, hidden)
        p1 = self.Conv_ct(torch.cat((input, gated_hidden), 1))
        ct = torch.tanh(p1)
        # convex combination of the previous hidden state and the candidate
        next_h = (1 - update_gate) * hidden + torch.mul(update_gate, ct)
        return next_h
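A minimal smoke test for the cell above (a sketch only; the 1×1 kernel matches how `GraphModel` in graph.py uses it, and the tensor sizes are illustrative):

```python
import torch
from ConvGRU import ConvGRUCell

cell = ConvGRUCell(input_size=16, hidden_size=16, kernel_size=1)
x = torch.randn(2, 16, 32, 32)   # incoming message, like m_t in graph.py
h = torch.randn(2, 16, 32, 32)   # previous node state, like h_t in graph.py
h_next = cell(x, h)
assert h_next.shape == h.shape   # the spatial resolution is preserved
```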
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2020 LA30

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Cas-Gnn
This project provides the code and results for '[Cascade Graph Neural Networks for RGB-D Salient Object Detection](https://arxiv.org/pdf/2008.03087.pdf)' (ECCV 2020).

## Saliency maps and Evaluation

The trained model is available on [GoogleDrive](https://drive.google.com/file/d/1wOPT2OVSihwkPBcOqK5YkuFV66D9Cuey/view?usp=sharing).

All of the saliency maps are available on [GoogleDrive](https://drive.google.com/file/d/1XXUCFUzD2g5Uamh7DzjQr-S7vX5lQ3yF/view?usp=sharing) or [BaiduYun](https://pan.baidu.com/s/1VVZbqYq9yc6Ouu0jrJTVuw) (extraction code: 3e23).

You can use the toolbox provided by [dpfan](http://dpfan.net/d3netbenchmark/) or [jiwei0921](https://github.com/jiwei0921/Saliency-Evaluation-Toolbox) for evaluation.

### If you think this work is helpful, please cite
```
@InProceedings{Luo2020CascadeGN,
  title={Cascade Graph Neural Networks for {RGB-D} Salient Object Detection},
  author={Ao Luo and Xin Li and Fan Yang and Zhicheng Jiao and Hong Cheng and Siwei Lyu},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2020},
}
```

If you have any questions, please contact me at aoluo_uestc@hotmail.com.

## Acknowledgement
Code is partially adapted from [DMRA](https://github.com/jiwei0921/DMRA). Please refer to [RGBD-SOD-datasets](https://github.com/jiwei0921/RGBD-SOD-datasets) for more details about the training and testing datasets.
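For a quick sanity check of the released checkpoint, a minimal inference sketch (the checkpoint path mirrors `eval.py`'s default; the 256×256 size and the random tensors are placeholders, not real data):

```python
import torch
import torch.nn.functional as F
from model import CasGnn

model = CasGnn().cuda().eval()
model.load_state_dict(torch.load('./checkpoints/CasGnn.pth'))

rgb = torch.randn(1, 3, 256, 256).cuda()   # normalised RGB image
dep = torch.randn(1, 1, 256, 256).cuda()
dep = dep.repeat(1, 3, 1, 1)               # eval.py tiles depth to 3 channels

with torch.no_grad():
    logits = model([rgb, dep])                  # [1, 2, 256, 256]
    saliency = F.softmax(logits, dim=1)[:, 1]   # foreground probability map
```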
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------

import os

import numpy as np
import PIL.Image
import torch
from torch.utils import data


class Data(data.Dataset):
    """
    Test-set loader: reads RGB images, depth maps and ground-truth masks
    from `test_images/`, `test_depth/` and `test_masks/` under `root`.
    """
    mean_rgb = np.array([0.447, 0.407, 0.386])
    std_rgb = np.array([0.244, 0.250, 0.253])

    def __init__(self, root, transform=False):
        super().__init__()
        self.root = root
        self._transform = transform

        img_root = os.path.join(self.root, 'test_images')
        depth_root = os.path.join(self.root, 'test_depth')
        mask_root = os.path.join(self.root, 'test_masks')

        file_names = os.listdir(img_root)
        self.img_names = []
        self.names = []
        self.depth_names = []
        self.mask_names = []

        for name in file_names:
            if not name.endswith('.jpg'):
                continue
            self.img_names.append(
                os.path.join(img_root, name)
            )
            self.names.append(name[:-4])
            self.depth_names.append(
                # os.path.join(depth_root, name[:-4] + '_depth.png')  # for the RGBD135 dataset
                os.path.join(depth_root, name[:-4] + '.png')
            )
            self.mask_names.append(
                os.path.join(mask_root, name[:-4] + '.png')
            )

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        # load image
        img_file = self.img_names[index]
        img = PIL.Image.open(img_file)
        img_size = img.size
        # img = img.resize((256, 256))

        # load label
        mask_file = self.mask_names[index]
        mask = PIL.Image.open(mask_file)

        # load depth
        depth_file = self.depth_names[index]
        depth = PIL.Image.open(depth_file)
        # depth = depth.resize((256, 256))

        img = np.array(img, dtype=np.uint8)
        mask = np.array(mask, dtype=np.uint8)
        mask[mask != 0] = 1          # binarise the ground-truth mask
        depth = np.array(depth, dtype=np.uint8)

        if self._transform:
            img, depth = self.transform(img, depth)
        return img, depth, mask, self.names[index], img_size

    def transform(self, img, depth):
        # normalise RGB with the dataset mean/std and move channels first
        img = img.astype(np.float64) / 255.0
        img -= self.mean_rgb
        img /= self.std_rgb
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()

        # scale depth to [0, 1]
        depth = depth.astype(np.float64) / 255.0
        depth = torch.from_numpy(depth).float()

        return img, depth
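A sketch of how the loader is consumed (the dataset root follows `eval.py`'s default and is only an assumption about where the data lives):

```python
from dataloader import Data

# expects <root>/test_images/*.jpg, <root>/test_depth/*.png, <root>/test_masks/*.png
dataset = Data('/data/Datasets/SOD/NJUD/test_data', transform=True)
img, depth, mask, name, size = dataset[0]
print(img.shape)    # torch.Size([3, H, W]), mean/std-normalised
print(depth.shape)  # torch.Size([H, W]), values in [0, 1]
print(mask.shape)   # (H, W) uint8 array with values {0, 1}
```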
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------

import argparse

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from dataloader import Data
from model import CasGnn


parser = argparse.ArgumentParser()
parser.add_argument('--checkpoints', type=str, default='./checkpoints/CasGnn.pth')
parser.add_argument('--data_root', type=str, default='/data/Datasets/SOD/NJUD/test_data')


def main(args):
    # data
    data_root = args.data_root
    test_loader = DataLoader(Data(data_root, transform=True),
                             batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    model = CasGnn().cuda()
    model.load_state_dict(torch.load(args.checkpoints))
    func_eval(test_loader, model)


def func_eval(test_loader, model):
    mae = 0
    model.eval()
    for idx, (data, depth, mask, img_name, img_size) in enumerate(test_loader):
        with torch.no_grad():
            inputs = data.cuda()
            depth = depth.cuda()
            n, c, h, w = inputs.size()
            # tile the single-channel depth map to match the 3-channel encoder input
            depth = depth.unsqueeze(1).repeat(1, c, 1, 1)

            pred = model([inputs, depth])
            out = F.softmax(pred, dim=1)
            # binary argmax map weighted by the foreground probability
            out = out.max(1)[1].squeeze_(1).float() * out[:, 1]

        mae += abs(out.detach().cpu() - mask.float()).mean()

    mae = mae / len(test_loader)
    print(' * MAE {mae:.3f}'.format(mae=mae))
    return mae


if __name__ == '__main__':
    main(parser.parse_args())
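The metric above is the standard saliency MAE; a self-contained sketch of the same computation, with random maps standing in for real predictions:

```python
import torch

def mae_score(pred: torch.Tensor, gt: torch.Tensor) -> float:
    """Mean absolute error between a saliency map and a binary mask, both in [0, 1]."""
    return (pred - gt.float()).abs().mean().item()

pred = torch.rand(256, 256)                # stand-in saliency map
gt = (torch.rand(256, 256) > 0.5).float()  # stand-in ground truth
print(mae_score(pred, gt))
```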
--------------------------------------------------------------------------------
/graph.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import ConvGRU


class GraphReasoning(nn.Module):
    """One cascade level: builds multi-scale graph nodes for the RGB and the
    depth streams, lets the two modalities interact, and then runs message
    passing within each modality."""

    def __init__(self, chnn_in, rd_sc, dila, n_iter):
        super().__init__()
        self.n_iter = n_iter
        self.ppm_rgb = PPM(chnn_in, rd_sc, dila)
        self.ppm_dep = PPM(chnn_in, rd_sc, dila)
        self.n_node = len(dila)
        self.graph_rgb = GraphModel(self.n_node, chnn_in // rd_sc)
        self.graph_dep = GraphModel(self.n_node, chnn_in // rd_sc)
        chnn = chnn_in * 2 // rd_sc
        # channel attention used to pass guidance from the previous cascade level
        C_ca = [nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(chnn, chnn // 4, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(chnn // 4, chnn_in // rd_sc, 1, bias=False))
            for ii in range(2)]
        self.C_ca = nn.ModuleList(C_ca)
        # spatial attention used for the cross-modality interaction
        C_pa = [nn.Conv2d(chnn_in // rd_sc, 1, 1, bias=False) for ii in range(2)]
        self.C_pa = nn.ModuleList(C_pa)

    def _enh(self, Func, src, dst):
        # enhance `dst` with channel attention computed from the guidance `src`
        out = torch.sigmoid(Func(src)) * dst + dst
        return out

    def _inn(self, Func, feat):
        # intra-modality message passing: stack the nodes and iterate the graph
        feat = [fm.unsqueeze(1) for fm in feat]
        feat = torch.cat(feat, 1)
        for ii in range(self.n_iter):
            feat = Func(feat)
        feat = torch.split(feat, 1, 1)
        feat = [fm.squeeze_(1) for fm in feat]
        return feat

    def _int(self, Func, src_1, src_2):
        # cross-modality interaction: each stream is enhanced by an attention
        # map computed from its difference with the other stream
        out_2 = src_1 * torch.sigmoid(Func[0](src_1 - src_2)) + src_2
        out_1 = src_2 * torch.sigmoid(Func[1](src_2 - src_1)) + src_1
        return out_1, out_2

    def forward(self, inputs, node=False):
        feat_rgb, feat_dep, nd_rgb, nd_dep = inputs
        feat_rgb = self.ppm_rgb(feat_rgb)
        feat_dep = self.ppm_dep(feat_dep)
        if node:
            # from the second cascade level on, condition the nodes on the
            # fused features of the previous level
            feat_rgb = [self._enh(self.C_ca[0], nd_rgb, fm) for fm in feat_rgb]
            feat_dep = [self._enh(self.C_ca[1], nd_dep, fm) for fm in feat_dep]
        for ii in range(self.n_node):
            # the cross-modality interaction is applied twice per node pair
            feat_rgb[ii], feat_dep[ii] = self._int([self.C_pa[0], self.C_pa[1]], feat_rgb[ii], feat_dep[ii])
            feat_rgb[ii], feat_dep[ii] = self._int([self.C_pa[0], self.C_pa[1]], feat_rgb[ii], feat_dep[ii])
        feat_rgb = self._inn(self.graph_rgb, feat_rgb)
        feat_dep = self._inn(self.graph_dep, feat_dep)
        return feat_rgb, feat_dep


class PPM(nn.Module):
    """Pyramid of dilated 3x3 convolutions; each branch yields one graph node."""

    def __init__(self, chnn_in, rd_sc, dila):
        super(PPM, self).__init__()
        chnn = chnn_in // rd_sc
        convs = [nn.Sequential(
            nn.Conv2d(chnn_in, chnn, 3, padding=ii, dilation=ii, bias=False),
            nn.BatchNorm2d(chnn),
            nn.ReLU(inplace=True))
            for ii in dila]
        self.convs = nn.ModuleList(convs)

    def forward(self, inputs):
        feats = []
        for conv in self.convs:
            feat = conv(inputs)
            feats.append(feat)
        return feats


class GraphModel(nn.Module):
    """Message passing among N node feature maps with a ConvGRU state update."""

    def __init__(self, N, chnn_in=256):
        super().__init__()
        self.n_node = N
        chnn = chnn_in
        # grouped 1x1 convolution producing one scalar weight map per neighbour
        self.C_wgt = nn.Conv2d(chnn * (N - 1), (N - 1), 1, groups=(N - 1), bias=False)
        self.ConvGRU = ConvGRU.ConvGRUCell(chnn, chnn, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, inputs):
        b, n, c, h, w = inputs.shape
        feat_s = [inputs[:, ii, :] for ii in range(self.n_node)]
        pred_s = []
        for idx_node in range(self.n_node):
            h_t = feat_s[idx_node]
            h_t_m = h_t.repeat(1, self.n_node - 1, 1, 1)
            h_n = torch.cat([feat_s[ii] for ii in range(self.n_node) if ii != idx_node], dim=1)
            # aggregate weighted messages from all neighbouring nodes
            msg = self._get_msg(h_t_m, h_n)
            m_t = torch.sum(msg.view(b, -1, c, h, w), dim=1)
            # update the node state with one ConvGRU step
            h_t = self.ConvGRU(m_t, h_t)
            base = feat_s[idx_node]
            # residual connection scaled by a learnable gamma (initialised to 0)
            pred_s.append(h_t * self.gamma + base)
        pred = torch.stack(pred_s).permute(1, 0, 2, 3, 4).contiguous()
        return pred

    def _get_msg(self, x1, x2):
        # edge weights from the difference between the node and its neighbours
        b, c, h, w = x1.shape
        wgt = self.C_wgt(x1 - x2).unsqueeze(1).repeat(
            1, c // (self.n_node - 1), 1, 1, 1).view(b, c, h, w)
        out = x2 * torch.sigmoid(wgt)
        return out
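A shape check for the graph module in isolation (the sizes are illustrative and deliberately small):

```python
import torch
from graph import GraphModel

gm = GraphModel(N=3, chnn_in=16)
nodes = torch.randn(2, 3, 16, 32, 32)   # [batch, nodes, channels, H, W]
refined = gm(nodes)
assert refined.shape == nodes.shape      # one refined map per input node
```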
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from graph import GraphReasoning


class VGG(nn.Module):
    """VGG-19 (BN) backbone; pool4 is skipped and conv5 uses dilated
    convolutions, so the deepest features stay at 1/8 of the input size."""

    def __init__(self):
        super(VGG, self).__init__()
        # conv1, 2 layers
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # conv2, 2 layers
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # conv3, 4 layers
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
        self.relu3_4 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # conv4, 4 layers
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu4_4 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        # conv5, 4 layers with increasing dilation (replaces pool4's stride)
        dila = [2, 4, 8, 16]
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=dila[0], dilation=dila[0])
        self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=dila[1], dilation=dila[1])
        self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=dila[2], dilation=dila[2])
        self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.conv5_4 = nn.Conv2d(512, 512, 3, padding=dila[3], dilation=dila[3])
        self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
        self.relu5_4 = nn.ReLU(inplace=True)

    def forward(self, x):
        h = x  # [3, 256, 256]
        h = self.relu1_1(self.bn1_1(self.conv1_1(h)))
        h = self.relu1_2(self.bn1_2(self.conv1_2(h)))
        h_nopool1 = h
        h = self.pool1(h)
        h1 = h_nopool1  # [64, 256, 256]
        h = self.relu2_1(self.bn2_1(self.conv2_1(h)))
        h = self.relu2_2(self.bn2_2(self.conv2_2(h)))
        h_nopool2 = h
        h = self.pool2(h)
        h2 = h_nopool2  # [128, 128, 128]
        h = self.relu3_1(self.bn3_1(self.conv3_1(h)))
        h = self.relu3_2(self.bn3_2(self.conv3_2(h)))
        h = self.relu3_3(self.bn3_3(self.conv3_3(h)))
        h = self.relu3_4(self.bn3_4(self.conv3_4(h)))
        h_nopool3 = h
        h = self.pool3(h)
        h3 = h_nopool3  # [256, 64, 64]
        h = self.relu4_1(self.bn4_1(self.conv4_1(h)))
        h = self.relu4_2(self.bn4_2(self.conv4_2(h)))
        h = self.relu4_3(self.bn4_3(self.conv4_3(h)))
        h = self.relu4_4(self.bn4_4(self.conv4_4(h)))
        h_nopool4 = h
        # h = self.pool4(h)  # pool4 is skipped: conv5 relies on dilation instead
        # h4 = h_nopool4  # [512, 32, 32]
        h = self.relu5_1(self.bn5_1(self.conv5_1(h)))
        h = self.relu5_2(self.bn5_2(self.conv5_2(h)))
        h = self.relu5_3(self.bn5_3(self.conv5_3(h)))
        h = self.relu5_4(self.bn5_4(self.conv5_4(h)))
        h5 = h  # [512, 32, 32]
        return h5, h3, h2  # h4 and h1 are unused

    def copy_params_from_vgg19_bn(self, vgg19_bn):
        # copy conv and BN affine parameters layer by layer from torchvision's
        # vgg19_bn; BN running statistics are not copied here and are
        # re-estimated during training
        features = [
            self.conv1_1, self.bn1_1, self.relu1_1,
            self.conv1_2, self.bn1_2, self.relu1_2,
            self.pool1,
            self.conv2_1, self.bn2_1, self.relu2_1,
            self.conv2_2, self.bn2_2, self.relu2_2,
            self.pool2,
            self.conv3_1, self.bn3_1, self.relu3_1,
            self.conv3_2, self.bn3_2, self.relu3_2,
            self.conv3_3, self.bn3_3, self.relu3_3,
            self.conv3_4, self.bn3_4, self.relu3_4,
            self.pool3,
            self.conv4_1, self.bn4_1, self.relu4_1,
            self.conv4_2, self.bn4_2, self.relu4_2,
            self.conv4_3, self.bn4_3, self.relu4_3,
            self.conv4_4, self.bn4_4, self.relu4_4,
            self.pool4,
            self.conv5_1, self.bn5_1, self.relu5_1,
            self.conv5_2, self.bn5_2, self.relu5_2,
            self.conv5_3, self.bn5_3, self.relu5_3,
            self.conv5_4, self.bn5_4, self.relu5_4,
        ]
        for l1, l2 in zip(vgg19_bn.features, features):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                assert l1.weight.size() == l2.weight.size()
                assert l1.bias.size() == l2.bias.size()
                l2.weight.data = l1.weight.data
                l2.bias.data = l1.bias.data
            if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d):
                assert l1.weight.size() == l2.weight.size()
                assert l1.bias.size() == l2.bias.size()
                l2.weight.data = l1.weight.data
                l2.bias.data = l1.bias.data
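
# Editor's sketch (not part of the original code): a hypothetical helper showing
# how both encoders can be initialised from torchvision's VGG19-BN via the
# copy_params_from_vgg19_bn method above. `pretrained=True` is the older
# torchvision API; newer versions use `weights=...`.
def load_pretrained_encoders(model):
    """Copy ImageNet VGG19-BN weights into both encoders of a CasGnn model."""
    import torchvision
    vgg19_bn = torchvision.models.vgg19_bn(pretrained=True)
    model.enc_rgb.copy_params_from_vgg19_bn(vgg19_bn)
    model.enc_dep.copy_params_from_vgg19_bn(vgg19_bn)
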
class CGR(nn.Module):
    """Cascade Graph Reasoning: stacks one GraphReasoning module per backbone
    stage and progressively upsamples the fused features to a 2-class map."""

    def __init__(self, n_class=2, n_iter=2, chnn_side=(512, 256, 128), chnn_targ=(512, 128, 32, 4), rd_sc=32, dila=(4, 8, 16)):
        super().__init__()
        self.n_graph = len(chnn_side)
        n_node = len(dila)
        graph = [GraphReasoning(ii, rd_sc, dila, n_iter) for ii in chnn_side]
        self.graph = nn.ModuleList(graph)
        # fuse the n_node graph outputs of each level back into one map
        C_cat = [nn.Sequential(
            nn.Conv2d(ii // rd_sc * n_node, ii // rd_sc, 3, 1, 1, bias=False),
            nn.BatchNorm2d(ii // rd_sc),
            nn.ReLU(inplace=True))
            for ii in (chnn_side + chnn_side)]
        self.C_cat = nn.ModuleList(C_cat)
        idx = list(range(len(chnn_side)))
        # upsampling heads; the first half serves the RGB stream, the second half the depth stream
        C_up = [nn.Sequential(
            nn.Conv2d(chnn_targ[ii] + chnn_side[ii] // rd_sc, chnn_targ[ii + 1], 3, 1, 1, bias=False),
            nn.BatchNorm2d(chnn_targ[ii + 1]),
            nn.ReLU(inplace=True))
            for ii in (idx + idx)]
        self.C_up = nn.ModuleList(C_up)
        self.C_cls = nn.Conv2d(chnn_targ[-1] * 2, n_class, 1)

    def forward(self, inputs):
        img, depth = inputs
        cas_rgb, cas_dep = img[0], depth[0]
        nd_rgb, nd_dep, nd_key = None, None, False
        for ii in range(self.n_graph):
            feat_rgb, feat_dep = self.graph[ii]([img[ii], depth[ii], nd_rgb, nd_dep], nd_key)
            feat_rgb = torch.cat(feat_rgb, 1)
            feat_rgb = self.C_cat[ii](feat_rgb)
            feat_dep = torch.cat(feat_dep, 1)
            feat_dep = self.C_cat[self.n_graph + ii](feat_dep)
            # the fused features guide the nodes of the next cascade level
            nd_rgb, nd_dep, nd_key = feat_rgb, feat_dep, True
            cas_rgb = torch.cat((feat_rgb, cas_rgb), 1)
            cas_rgb = F.interpolate(cas_rgb, scale_factor=2, mode='bilinear', align_corners=True)
            cas_rgb = self.C_up[ii](cas_rgb)
            cas_dep = torch.cat((feat_dep, cas_dep), 1)
            cas_dep = F.interpolate(cas_dep, scale_factor=2, mode='bilinear', align_corners=True)
            cas_dep = self.C_up[self.n_graph + ii](cas_dep)
        feat = torch.cat((cas_rgb, cas_dep), 1)
        out = self.C_cls(feat)
        return out


class CasGnn(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc_rgb = VGG()   # RGB encoder
        self.enc_dep = VGG()   # depth encoder
        # Cascade Graph Reasoning
        self.graph = CGR()

    def forward(self, inputs):
        img, depth = inputs
        feat_rgb = self.enc_rgb(img)
        feat_dep = self.enc_dep(depth)
        out = self.graph([feat_rgb, feat_dep])
        return out

--------------------------------------------------------------------------------
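An end-to-end smoke test with random tensors (256×256 is an assumed size; any input whose sides are divisible by 8 should round-trip through the three pooling and three upsampling stages):

```python
import torch
from model import CasGnn

model = CasGnn().eval()
rgb = torch.randn(1, 3, 256, 256)
dep = torch.randn(1, 3, 256, 256)   # depth tiled to 3 channels, as in eval.py
with torch.no_grad():
    out = model([rgb, dep])
print(out.shape)   # torch.Size([1, 2, 256, 256])
```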