├── CoNet ├── __pycache__ │ ├── attention.cpython-35.pyc │ ├── attention.cpython-36.pyc │ ├── collector.cpython-36.pyc │ ├── dataset_loader.cpython-35.pyc │ ├── dataset_loader.cpython-36.pyc │ ├── functions.cpython-35.pyc │ ├── functions.cpython-36.pyc │ ├── fusion.cpython-35.pyc │ ├── fusion.cpython-36.pyc │ ├── model.cpython-35.pyc │ ├── model.cpython-36.pyc │ ├── trainer.cpython-35.pyc │ └── trainer.cpython-36.pyc ├── collector.py ├── dataset_loader.py ├── demo.py ├── functions.py ├── gen_edge.py ├── model.py ├── modeling │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── batchnorm.cpython-35.pyc │ │ ├── batchnorm.cpython-36.pyc │ │ ├── comm.cpython-35.pyc │ │ ├── comm.cpython-36.pyc │ │ ├── replicate.cpython-35.pyc │ │ └── replicate.cpython-36.pyc │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py ├── trainer.py └── transform.py ├── Comparison.png ├── README.md ├── egbib.bib └── overall111.png /CoNet/__pycache__/attention.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/attention.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/collector.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/collector.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/dataset_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/dataset_loader.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/dataset_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/dataset_loader.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/functions.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/functions.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/functions.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/fusion.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/fusion.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/fusion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/fusion.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/model.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/trainer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/trainer.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/collector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | class TriAtt(nn.Module): 8 | def __init__(self): 9 | super(TriAtt, self).__init__() 10 | 11 | self.conv_concat = nn.Conv2d(2*2, 2, 1, padding=0) 12 | self.predict = nn.Conv2d(128, 2, 1, padding=0) 13 | 14 | 15 | self._initialize_weights() 16 | 17 | def _initialize_weights(self): 18 | for m in self.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | nn.init.normal_(m.weight.data, std=0.01) 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | 24 | 25 | def forward(self,Features, Edge, Sal, Depth): 26 | 27 | # -------------------- Knowledge Collector ------------------- # 28 | Feature_d = torch.mul(Features, Depth) 29 | Feature_d = Features + Feature_d 30 | 31 | Att = torch.cat([Edge,Sal],dim=1) 32 | Att = self.conv_concat(Att) 33 | Att = F.softmax(Att,dim=1)[:,1:,:] 34 | 35 | Feature_dcs = torch.mul(Feature_d, Att) 36 | Feature_all = Feature_dcs + Feature_d 37 | 38 | outputs = self.predict(Feature_all) 39 | 40 | return outputs 41 | 42 | -------------------------------------------------------------------------------- /CoNet/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import PIL.Image 5 | import scipy.io as sio 6 | import torch 7 | from torch.utils import data 8 | 9 | class MyData(data.Dataset): # inherit 10 | """ 11 | load data in a folder 12 | """ 13 | mean_rgb = np.array([0.447, 0.407, 0.386]) 14 | std_rgb = np.array([0.244, 0.250, 0.253]) 15 | 
def __init__(self, root, transform=False): 16 | super(MyData, self).__init__() 17 | self.root = root 18 | 19 | self._transform = transform 20 | 21 | img_root = os.path.join(self.root, 'train_images') 22 | lbl_root = os.path.join(self.root, 'train_masks') 23 | depth_root = os.path.join(self.root, 'train_depth') 24 | edge_root = os.path.join(self.root, 'train_edge') 25 | 26 | file_names = os.listdir(img_root) 27 | self.img_names = [] 28 | self.lbl_names = [] 29 | self.depth_names = [] 30 | self.edge_names = [] 31 | for i, name in enumerate(file_names): 32 | if not name.endswith('.jpg'): 33 | continue 34 | self.lbl_names.append( 35 | os.path.join(lbl_root, name[:-4]+'.png') 36 | ) 37 | self.img_names.append( 38 | os.path.join(img_root, name) 39 | ) 40 | self.depth_names.append( 41 | os.path.join(depth_root, name[:-4] + '.png') 42 | ) 43 | self.edge_names.append( 44 | os.path.join(edge_root, name[:-4] + '.png') 45 | ) 46 | 47 | def __len__(self): 48 | return len(self.img_names) 49 | 50 | def __getitem__(self, index): 51 | # load image 52 | img_file = self.img_names[index] 53 | img = PIL.Image.open(img_file) 54 | # img = img.resize((256, 256)) 55 | img = np.array(img, dtype=np.uint8) 56 | # load label 57 | lbl_file = self.lbl_names[index] 58 | lbl = PIL.Image.open(lbl_file) 59 | # lbl = lbl.resize((256, 256)) 60 | lbl = np.array(lbl, dtype=np.int32) 61 | lbl[lbl != 0] = 1 62 | # load depth 63 | depth_file = self.depth_names[index] 64 | depth = PIL.Image.open(depth_file) 65 | # depth = depth.resize(256, 256) 66 | depth = np.array(depth, dtype=np.uint8) 67 | # load edge 68 | edge_file = self.edge_names[index] 69 | edge = PIL.Image.open(edge_file) 70 | edge = np.array(edge, dtype=np.uint8) 71 | 72 | 73 | 74 | if self._transform: 75 | return self.transform(img, lbl, depth, edge) 76 | else: 77 | return img, lbl, depth, edge 78 | 79 | 80 | # Translating numpy_array into format that pytorch can use on Code. 
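# The transform below converts the HWC uint8 arrays loaded in __getitem__ into network-ready tensors:
# the RGB image is scaled to [0, 1], normalized with the class-level mean_rgb/std_rgb, and permuted to a
# CHW float tensor; the depth map becomes a float tensor in [0, 1]; the label and edge maps are cast to
# long tensors so they can serve as class-index targets.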
81 | def transform(self, img, lbl, depth, edge): 82 | 83 | img = img.astype(np.float64)/255.0 84 | img -= self.mean_rgb 85 | img /= self.std_rgb 86 | img = img.transpose(2, 0, 1) # to verify 87 | img = torch.from_numpy(img).float() 88 | lbl = torch.from_numpy(lbl).long() 89 | depth = depth.astype(np.float64)/255.0 90 | depth = torch.from_numpy(depth).float() 91 | edge = edge.astype(np.float64) / 255.0 92 | edge = torch.from_numpy(edge).long() 93 | return img, lbl, depth, edge 94 | 95 | 96 | class MyTestData(data.Dataset): 97 | """ 98 | load data in a folder 99 | """ 100 | mean_rgb = np.array([0.447, 0.407, 0.386]) 101 | std_rgb = np.array([0.244, 0.250, 0.253]) 102 | 103 | 104 | def __init__(self, root, transform=False): 105 | super(MyTestData, self).__init__() 106 | self.root = root 107 | self._transform = transform 108 | 109 | img_root = os.path.join(self.root, 'test_images') 110 | file_names = os.listdir(img_root) 111 | self.img_names = [] 112 | self.names = [] 113 | 114 | for i, name in enumerate(file_names): 115 | if not name.endswith('.jpg'): 116 | continue 117 | self.img_names.append( 118 | os.path.join(img_root, name) 119 | ) 120 | self.names.append(name[:-4]) 121 | 122 | 123 | def __len__(self): 124 | return len(self.img_names) 125 | 126 | def __getitem__(self, index): 127 | # load image 128 | img_file = self.img_names[index] 129 | img = PIL.Image.open(img_file) 130 | img_size = img.size 131 | # img = img.resize((256, 256)) 132 | img = np.array(img, dtype=np.uint8) 133 | 134 | 135 | if self._transform: 136 | img = self.transform(img) 137 | return img, self.names[index], img_size 138 | else: 139 | return img, self.names[index], img_size 140 | 141 | def transform(self, img): 142 | img = img.astype(np.float64)/255.0 143 | img -= self.mean_rgb 144 | img /= self.std_rgb 145 | img = img.transpose(2, 0, 1) 146 | img = torch.from_numpy(img).float() 147 | 148 | return img 149 | -------------------------------------------------------------------------------- /CoNet/demo.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.utils.data import DataLoader 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | 9 | from dataset_loader import MyData, MyTestData 10 | from model import ResNet101,Integration 11 | from collector import TriAtt 12 | from functions import imsave 13 | from trainer import Trainer 14 | 15 | import os 16 | import argparse 17 | import time 18 | 19 | configurations = { 20 | 1: dict( 21 | max_iteration=2000000, 22 | lr=1.0e-10, 23 | momentum=0.99, 24 | weight_decay=0.0005, 25 | spshot=20000, 26 | nclass=2, 27 | sshow=10, 28 | )} 29 | 30 | parameters = { 31 | "phase":"test", # train or test 32 | "param":True, # True or False 33 | "dataset":"NJUD", # DUT-RGBD, NJUD, NLPR, STEREO, LFSD, RGBD135 34 | "snap_num":str(1200000)+'.pth', # Snapshot Number 35 | } 36 | 37 | 38 | parser=argparse.ArgumentParser() 39 | parser.add_argument('--phase', type=str, default=parameters["phase"], help='train or test') 40 | parser.add_argument('--param', type=str, default=parameters["param"], help='path to pre-trained parameters') 41 | parser.add_argument('--train_dataroot', type=str, default='/dockerdata/weiji/Code/Data/E_Depth/train_data/', help='path to train data') 42 | parser.add_argument('--test_dataroot', type=str, default='/dockerdata/weiji/Code/Data/E_Depth/test_data/'+ parameters["dataset"], help='path to test data') 43 | 
parser.add_argument('--snapshot_root', type=str, default='../Out/snapshot', help='path to snapshot') 44 | parser.add_argument('--salmap_root', type=str, default='../Out/sal_map', help='path to saliency map') 45 | parser.add_argument('--out1', type=str, default='../Out/edge', help='path to saliency map') 46 | parser.add_argument('--out2', type=str, default='../Out/depth', help='path to saliency map') 47 | parser.add_argument('--out3', type=str, default='../Out/sal_att', help='path to saliency map') 48 | parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys()) 49 | args = parser.parse_args() 50 | cfg = configurations[args.config] 51 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 52 | cuda = torch.cuda.is_available() 53 | 54 | 55 | """""""""""~~~ dataset loader ~~~""""""""" 56 | train_dataRoot = args.train_dataroot 57 | test_dataRoot = args.test_dataroot 58 | if not os.path.exists(args.snapshot_root): 59 | os.mkdir(args.snapshot_root) 60 | if not os.path.exists(args.salmap_root): 61 | os.mkdir(args.salmap_root) 62 | if not os.path.exists(args.out1): 63 | os.mkdir(args.out1) 64 | if not os.path.exists(args.out2): 65 | os.mkdir(args.out2) 66 | if not os.path.exists(args.out3): 67 | os.mkdir(args.out3) 68 | 69 | if args.phase == 'train': 70 | SnapRoot = args.snapshot_root # checkpoint 71 | train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True),batch_size=2, shuffle=True, num_workers=4, pin_memory=True) 72 | else: 73 | MapRoot = args.salmap_root 74 | test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True),batch_size=1, shuffle=True, num_workers=4, pin_memory=True) 75 | 76 | print ('data already') 77 | 78 | 79 | """"""""""" ~~~nets~~~ """"""""" 80 | start_epoch = 0 81 | start_iteration = 0 82 | model_rgb = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=bool(1-args.param), output_stride=16) 83 | model_intergration = Integration() 84 | model_att = TriAtt() 85 | 86 | 87 | if args.param is True: 88 | model_rgb.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'snapshot_iter_'+ parameters["snap_num"]))) 89 | model_intergration.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'integrate_snapshot_iter_'+ parameters["snap_num"]))) 90 | model_att.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'att_snapshot_iter_'+ parameters["snap_num"]))) 91 | 92 | 93 | if cuda: 94 | model_rgb = model_rgb.cuda() 95 | model_intergration = model_intergration.cuda() 96 | model_att = model_att.cuda() 97 | 98 | 99 | if args.phase == 'train': 100 | 101 | #Trainer: class, defined in trainer.py 102 | optimizer_rgb = optim.SGD(model_rgb.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) 103 | optimizer_inter = optim.SGD(model_intergration.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) 104 | optimizer_att = optim.SGD(model_att.parameters(), lr=cfg['lr'],momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) 105 | training = Trainer( 106 | cuda=cuda, 107 | model_rgb=model_rgb, 108 | model_intergration=model_intergration, 109 | model_att=model_att, 110 | optimizer_rgb=optimizer_rgb, 111 | optimizer_inter=optimizer_inter, 112 | optimizer_att=optimizer_att, 113 | train_loader=train_loader, 114 | max_iter=cfg['max_iteration'], 115 | snapshot=cfg['spshot'], 116 | outpath=args.snapshot_root, 117 | sshow=cfg['sshow'] 118 | ) 119 | training.epoch = start_epoch 120 | training.iteration = start_iteration 121 | training.train() 122 | else: 123 | res 
= []
 124 |     for id, (data, img_name, img_size) in enumerate(test_loader):
 125 |         print('testing batch %d' % (id+1))
 126 | 
 127 |         inputs = Variable(data).cuda()
 128 |         n, c, h, w = inputs.size()
 129 |         begin_time = time.time()
 130 | 
 131 |         low_1, low_2, high_1, high_2, high_3 = model_rgb(inputs)
 132 |         Features, _, _, Edge, _, _, Depth, Sal = model_intergration(low_1, low_2, high_1, high_2, high_3)
 133 |         outputs = model_att(Features, Edge, Sal, Depth)
 134 |         outputs = F.softmax(outputs, dim=1)
 135 |         outputs = outputs[0][1]
 136 |         outputs = outputs.cpu().data.resize_(h, w)
 137 |         end_time = time.time()
 138 |         run_time = end_time - begin_time
 139 |         res.append(run_time)
 140 |         imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size)
 141 | 
 142 |         # ---------------- Visual Results ------------------ #
 143 |         # Edge
 144 |         out1 = F.softmax(Edge, dim=1)
 145 |         out1 = out1[0][1]
 146 |         out1 = out1.cpu().data.resize_(h, w)
 147 |         imsave(os.path.join(args.out1, img_name[0] + '.png'), out1, img_size)
 148 |         # Depth
 149 |         out2 = Depth[0][0]
 150 |         out2 = out2.cpu().data.resize_(h, w)
 151 |         imsave(os.path.join(args.out2, img_name[0] + '.png'), out2, img_size)
 152 |         # Sal-Att
 153 |         out3 = Sal[0][1]
 154 |         out3 = out3.cpu().data.resize_(h, w)
 155 |         imsave(os.path.join(args.out3, img_name[0] + '.png'), out3, img_size)
 156 |         # -------------------------------------------------- #
 157 | 
 158 |     print('The testing process has finished!')
 159 |     time_sum = 0
 160 |     for i in res:
 161 |         time_sum += i
 162 |     print("FPS: %f"%(1.0/(time_sum/len(res))))
 163 | 
--------------------------------------------------------------------------------
/CoNet/functions.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import matplotlib.pyplot as plt
 3 | import torch
 4 | #from scipy.misc import imresize
 5 | 
 6 | def imsave(file_name, img, img_size):
 7 |     """
 8 |     save a torch tensor as an image
 9 |     :param file_name: 'image/folder/image_name'
10 |     :param img: 3*h*w torch tensor
11 |     :return: nothing
12 |     """
13 |     assert isinstance(img, torch.FloatTensor), \
14 |         'img must be a torch.FloatTensor'
15 |     ndim = len(img.size())
16 |     assert ndim == 2 or ndim == 3, \
17 |         'img must be a 2 or 3 dimensional tensor'
18 | 
19 |     img = img.numpy()
20 |     # img = imresize(img, [img_size[1][0], img_size[0][0]], interp='nearest')
21 |     if ndim == 3:
22 |         plt.imsave(file_name, np.transpose(img, (1, 2, 0)))
23 |     else:
24 |         plt.imsave(file_name, img, cmap='gray')
25 | 
--------------------------------------------------------------------------------
/CoNet/gen_edge.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import os
 3 | 
 4 | def Edge_Extract(root):
 5 |     img_root = os.path.join(root,'train_masks')
 6 |     edge_root = os.path.join(root,'train_edge')
 7 | 
 8 |     if not os.path.exists(edge_root):
 9 |         os.mkdir(edge_root)
10 | 
11 |     file_names = os.listdir(img_root)
12 |     img_name = []
13 | 
14 |     for name in file_names:
15 |         if not name.endswith('.png'):
16 |             raise AssertionError("This file %s is not PNG" % (name))
17 |         img_name.append(os.path.join(img_root,name[:-4]+'.png'))
18 | 
19 |     index = 0
20 |     for image in img_name:
21 |         img = cv2.imread(image,0)
22 |         cv2.imwrite(edge_root+'/'+file_names[index],cv2.Canny(img,30,100))
23 |         index += 1
24 |     return 0
25 | 
26 | 
27 | if __name__ == '__main__':
28 |     root = '/data/jiwei/Datasets/train_data'
29 |     Edge_Extract(root)
30 | 
--------------------------------------------------------------------------------
/CoNet/model.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | import torch.utils.model_zoo as model_zoo 7 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 8 | 9 | 10 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 11 | """Make a 2D bilinear kernel suitable for upsampling""" 12 | factor = (kernel_size + 1) // 2 13 | if kernel_size % 2 == 1: 14 | center = factor - 1 15 | else: 16 | center = factor - 0.5 17 | og = np.ogrid[:kernel_size, :kernel_size] 18 | filt = (1 - abs(og[0] - center) / factor) * \ 19 | (1 - abs(og[1] - center) / factor) 20 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 21 | dtype=np.float64) 22 | weight[range(in_channels), range(out_channels), :, :] = filt 23 | return torch.from_numpy(weight).float() 24 | 25 | 26 | # ---------------------------------- ResNet101 ---------------------------------- # 27 | class Bottleneck(nn.Module): 28 | expansion = 4 29 | 30 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 31 | super(Bottleneck, self).__init__() 32 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 33 | self.bn1 = BatchNorm(planes) 34 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 35 | dilation=dilation, padding=dilation, bias=False) 36 | self.bn2 = BatchNorm(planes) 37 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 38 | self.bn3 = BatchNorm(planes * 4) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.downsample = downsample 41 | self.stride = stride 42 | self.dilation = dilation 43 | 44 | def forward(self, x): 45 | residual = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv3(out) 56 | out = self.bn3(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | class ResNet(nn.Module): 67 | 68 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 69 | self.inplanes = 64 70 | super(ResNet, self).__init__() 71 | blocks = [1, 2, 4] 72 | if output_stride == 16: 73 | strides = [1, 2, 2, 1] 74 | dilations = [1, 1, 1, 2] 75 | elif output_stride == 8: 76 | strides = [1, 2, 1, 1] 77 | dilations = [1, 1, 2, 4] 78 | else: 79 | raise NotImplementedError 80 | 81 | # Modules 82 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 83 | bias=False) 84 | self.bn1 = BatchNorm(64) 85 | self.relu = nn.ReLU(inplace=True) 86 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 87 | 88 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 89 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 90 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 91 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 92 | self._init_weight() 93 | 94 | if pretrained: 95 | self._load_pretrained_model() 96 | 97 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 98 | downsample = None 99 | if stride != 
1 or self.inplanes != planes * block.expansion: 100 | downsample = nn.Sequential( 101 | nn.Conv2d(self.inplanes, planes * block.expansion, 102 | kernel_size=1, stride=stride, bias=False), 103 | BatchNorm(planes * block.expansion), 104 | ) 105 | 106 | layers = [] 107 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 108 | self.inplanes = planes * block.expansion 109 | for i in range(1, blocks): 110 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 111 | 112 | return nn.Sequential(*layers) 113 | 114 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | nn.Conv2d(self.inplanes, planes * block.expansion, 119 | kernel_size=1, stride=stride, bias=False), 120 | BatchNorm(planes * block.expansion), 121 | ) 122 | 123 | layers = [] 124 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 125 | downsample=downsample, BatchNorm=BatchNorm)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, len(blocks)): 128 | layers.append(block(self.inplanes, planes, stride=1, 129 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, input): 134 | x = self.conv1(input) 135 | x = self.bn1(x) 136 | x = self.relu(x) 137 | x = self.maxpool(x) 138 | low_1= x 139 | 140 | x = self.layer1(x) 141 | low_2 = x 142 | x = self.layer2(x) 143 | high_1 = x 144 | x = self.layer3(x) 145 | high_2 = x 146 | x = self.layer4(x) 147 | high_3 = x 148 | return low_1, low_2, high_1, high_2, high_3 149 | 150 | def _init_weight(self): 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d): 153 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 154 | m.weight.data.normal_(0, math.sqrt(2. / n)) 155 | elif isinstance(m, SynchronizedBatchNorm2d): 156 | m.weight.data.fill_(1) 157 | m.bias.data.zero_() 158 | elif isinstance(m, nn.BatchNorm2d): 159 | m.weight.data.fill_(1) 160 | m.bias.data.zero_() 161 | 162 | def _load_pretrained_model(self): 163 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 164 | model_dict = {} 165 | state_dict = self.state_dict() 166 | for k, v in pretrain_dict.items(): 167 | if k in state_dict: 168 | model_dict[k] = v 169 | state_dict.update(model_dict) 170 | self.load_state_dict(state_dict) 171 | 172 | def ResNet101(output_stride, BatchNorm, pretrained=True): 173 | """Constructs a ResNet-101 model. 
174 | Args: 175 | pretrained (bool): If True, returns a model pre-trained on ImageNet 176 | """ 177 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 178 | return model 179 | # ------------------------------------------- End ------------------------------------------ # 180 | 181 | 182 | 183 | # --------------------------------- Collaborative Learning --------------------------------- # 184 | 185 | class Integration(nn.Module): 186 | def __init__(self): 187 | super(Integration, self).__init__() 188 | 189 | # ----------> Feature Extract <----------- 190 | # conv3 191 | self.conv3_0 = nn.Conv2d(512, 64, 1, padding=0) 192 | self.bn3_0 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 193 | self.relu3_0 = nn.PReLU() 194 | 195 | # conv4 196 | self.conv4_0 = nn.Conv2d(1024, 64, 1, padding=0) 197 | self.bn4_0 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 198 | self.relu4_0 = nn.PReLU() 199 | 200 | # conv5 201 | self.conv5_0 = nn.Conv2d(2048, 64, 1, padding=0) 202 | self.bn5_0 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 203 | self.relu5_0 = nn.PReLU() 204 | 205 | # ----------> Feature Extract < ----------- 206 | # ===== Low-level features ===== 207 | self.conv_low = nn.Conv2d(256 + 64, 64, 3, padding=1) 208 | self.bn_low = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 209 | self.relu_low = nn.PReLU() 210 | 211 | # ===== High-level features [GPM] ===== 212 | # --- Conv5 --- 213 | # part0: 1*1*64 Conv 214 | self.C5_conv1 = nn.Conv2d(64, 64, 1, padding=0) 215 | self.C5_bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 216 | self.C5_relu1 = nn.PReLU() 217 | # part1: 3*3*64 Conv dilation =1 218 | self.C5_conv2 = nn.Conv2d(64, 64, 3, padding=1) 219 | self.C5_bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 220 | self.C5_relu2 = nn.PReLU() 221 | # part2: 3*3*64 Conv dilation =6 222 | self.C5_conv3 = nn.Conv2d(64, 64, 3, padding=6, dilation=6) 223 | self.C5_bn3 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 224 | self.C5_relu3 = nn.PReLU() 225 | # part3: 3*3*64 Conv dilation =12 226 | self.C5_conv4 = nn.Conv2d(64, 64, 3, padding=12, dilation=12) 227 | self.C5_bn4 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 228 | self.C5_relu4 = nn.PReLU() 229 | # part4: 3*3*64 Conv dilation =18 230 | self.C5_conv5 = nn.Conv2d(64, 64, 3, padding=18, dilation=18) 231 | self.C5_bn5 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 232 | self.C5_relu5 = nn.PReLU() 233 | # part5: 1*1*64 Conv Concatenation 234 | self.C5_conv = nn.Conv2d(64 * 5, 64, 1, padding=0) 235 | self.C5_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 236 | self.C5_relu = nn.PReLU() 237 | 238 | # --- Conv4 --- 239 | # part0: 1*1*64 Conv 240 | self.C4_conv1 = nn.Conv2d(64, 64, 1, padding=0) 241 | self.C4_bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 242 | self.C4_relu1 = nn.PReLU() 243 | # part1: 3*3*64 Conv dilation =1 244 | self.C4_conv2 = nn.Conv2d(64, 64, 3, padding=1) 245 | self.C4_bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 246 | self.C4_relu2 = nn.PReLU() 247 | # part2: 3*3*64 Conv dilation =6 248 | self.C4_conv3 = nn.Conv2d(64, 64, 3, padding=6, dilation=6) 249 | self.C4_bn3 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 250 | self.C4_relu3 = nn.PReLU() 251 | # part3: 3*3*64 Conv dilation =12 252 | self.C4_conv4 = nn.Conv2d(64, 64, 3, padding=12, dilation=12) 253 | self.C4_bn4 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 254 
| self.C4_relu4 = nn.PReLU() 255 | # part4: 3*3*64 Conv dilation =18 256 | self.C4_conv5 = nn.Conv2d(64, 64, 3, padding=18, dilation=18) 257 | self.C4_bn5 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 258 | self.C4_relu5 = nn.PReLU() 259 | # part5: 1*1*64 Conv Concatenation 260 | self.C4_conv = nn.Conv2d(64 * 5, 64, 1, padding=0) 261 | self.C4_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 262 | self.C4_relu = nn.PReLU() 263 | 264 | # --- Conv3 --- 265 | # part0: 1*1*64 Conv 266 | self.C3_conv1 = nn.Conv2d(64, 64, 1, padding=0) 267 | self.C3_bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 268 | self.C3_relu1 = nn.PReLU() 269 | # part1: 3*3*64 Conv dilation =1 270 | self.C3_conv2 = nn.Conv2d(64, 64, 3, padding=1) 271 | self.C3_bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 272 | self.C3_relu2 = nn.PReLU() 273 | # part2: 3*3*64 Conv dilation =6 274 | self.C3_conv3 = nn.Conv2d(64, 64, 3, padding=6, dilation=6) 275 | self.C3_bn3 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 276 | self.C3_relu3 = nn.PReLU() 277 | # part3: 3*3*64 Conv dilation =12 278 | self.C3_conv4 = nn.Conv2d(64, 64, 3, padding=12, dilation=12) 279 | self.C3_bn4 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 280 | self.C3_relu4 = nn.PReLU() 281 | # part4: 3*3*64 Conv dilation =18 282 | self.C3_conv5 = nn.Conv2d(64, 64, 3, padding=18, dilation=18) 283 | self.C3_bn5 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 284 | self.C3_relu5 = nn.PReLU() 285 | # part5: 1*1*64 Conv Concatenation 286 | self.C3_conv = nn.Conv2d(64 * 5, 64, 1, padding=0) 287 | self.C3_bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 288 | self.C3_relu = nn.PReLU() 289 | 290 | self.conv_high = nn.Conv2d(64 * 3, 64, 3, padding=1) 291 | self.bn_high = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 292 | self.relu_high = nn.PReLU() 293 | 294 | # -------------------- Integration Training ------------------- # 295 | # Low-level Integration 296 | 297 | self.low_sal = nn.Conv2d(64, 2, 1, padding=0) 298 | self.pred1_sal = nn.Conv2d(64, 2, 1, padding=0) 299 | self.pred1_edge = nn.Conv2d(64, 2, 1, padding=0) 300 | 301 | # High-level Integration 302 | self.high_depth = nn.Conv2d(64, 1, 1, padding=0) 303 | self.high_sal = nn.Conv2d(64, 2, 1, padding=0) 304 | self.pred2_sal = nn.Conv2d(64, 2, 1, padding=0) 305 | self.pred2_depth = nn.Conv2d(64, 1, 1, padding=0) 306 | # channel attention 307 | self.conv_ca = nn.Conv2d(1, 64, 3, padding=1) 308 | self.pool_avg = nn.AvgPool2d(64, stride=2, ceil_mode=True) # 1/8 309 | 310 | # Depth CNN 311 | self.D_conv1 = nn.Conv2d(64, 64, 3, padding=1) 312 | self.D_bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 313 | self.D_relu1 = nn.PReLU() 314 | self.D_conv2 = nn.Conv2d(64, 64, 3, padding=1) 315 | self.D_bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 316 | self.D_relu2 = nn.PReLU() 317 | self.D_conv3 = nn.Conv2d(64, 64, 3, padding=1) 318 | self.D_bn3 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 319 | self.D_relu3 = nn.PReLU() 320 | 321 | self._initialize_weights() 322 | 323 | 324 | def _initialize_weights(self): 325 | for m in self.modules(): 326 | if isinstance(m, nn.Conv2d): 327 | # m.weight.data.zero_() 328 | nn.init.normal_(m.weight.data, std=0.01) 329 | if m.bias is not None: 330 | m.bias.data.zero_() 331 | if isinstance(m, nn.ConvTranspose2d): 332 | assert m.kernel_size[0] == m.kernel_size[1] 333 | initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, 
m.kernel_size[0]) 334 | m.weight.data.copy_(initial_weight) 335 | 336 | def forward(self, low_1, low_2, high_1, high_2, high_3): 337 | 338 | # -----------> Feature Extract <----------- # 339 | # Low-level features 340 | Low = F.interpolate(torch.cat([low_1, low_2], dim=1), scale_factor=4, mode='bilinear', align_corners=False) 341 | Low = self.relu_low(self.bn_low(self.conv_low(Low))) 342 | 343 | # High-level features 344 | h3 = self.relu3_0(self.bn3_0(self.conv3_0(high_1))) 345 | h3 = F.interpolate(h3, scale_factor=2, mode='bilinear', align_corners=False) 346 | h4 = self.relu4_0(self.bn4_0(self.conv4_0(high_2))) 347 | h5 = self.relu5_0(self.bn5_0(self.conv5_0(high_3))) 348 | 349 | c5 = h5 350 | Conv5_1 = self.C5_relu1(self.C5_bn1(self.C5_conv1(c5))) 351 | Conv5_2 = self.C5_relu2(self.C5_bn2(self.C5_conv2(c5))) 352 | Conv5_3 = self.C5_relu3(self.C5_bn3(self.C5_conv3(c5))) 353 | Conv5_4 = self.C5_relu4(self.C5_bn4(self.C5_conv4(c5))) 354 | Conv5_5 = self.C5_relu5(self.C5_bn5(self.C5_conv5(c5))) 355 | Conv5_ori = self.C5_relu( 356 | self.C5_bn(self.C5_conv(torch.cat([Conv5_1, Conv5_2, Conv5_3, Conv5_4, Conv5_5], dim=1)))) 357 | Conv5 = F.interpolate(Conv5_ori, scale_factor=4, mode='bilinear', align_corners=False) 358 | 359 | c4 = Conv5_ori + h4 360 | Conv4_1 = self.C4_relu1(self.C4_bn1(self.C4_conv1(c4))) 361 | Conv4_2 = self.C4_relu2(self.C4_bn2(self.C4_conv2(c4))) 362 | Conv4_3 = self.C4_relu3(self.C4_bn3(self.C4_conv3(c4))) 363 | Conv4_4 = self.C4_relu4(self.C4_bn4(self.C4_conv4(c4))) 364 | Conv4_5 = self.C4_relu5(self.C4_bn5(self.C4_conv5(c4))) 365 | Conv4_ori = self.C4_relu( 366 | self.C4_bn(self.C4_conv(torch.cat([Conv4_1, Conv4_2, Conv4_3, Conv4_4, Conv4_5], dim=1)))) 367 | Conv4 = F.interpolate(Conv4_ori, scale_factor=4, mode='bilinear', align_corners=False) 368 | 369 | c3 = Conv5 + Conv4 + h3 370 | Conv3_1 = self.C3_relu1(self.C3_bn1(self.C3_conv1(c3))) 371 | Conv3_2 = self.C3_relu2(self.C3_bn2(self.C3_conv2(c3))) 372 | Conv3_3 = self.C3_relu3(self.C3_bn3(self.C3_conv3(c3))) 373 | Conv3_4 = self.C3_relu4(self.C3_bn4(self.C3_conv4(c3))) 374 | Conv3_5 = self.C3_relu5(self.C3_bn5(self.C3_conv5(c3))) 375 | Conv3 = self.C3_relu(self.C3_bn(self.C3_conv(torch.cat([Conv3_1, Conv3_2, Conv3_3, Conv3_4, Conv3_5], dim=1)))) 376 | 377 | High = torch.cat([Conv3, Conv4, Conv5], dim=1) 378 | High = self.relu_high(self.bn_high(self.conv_high(High))) 379 | 380 | # obtain Low and High 381 | # -----------> Integration Training <----------- # 382 | # Low-level Integration 383 | 384 | pred_edge1 = self.pred1_edge(Low) 385 | 386 | # Depth CNN 387 | D1=self.D_relu1(self.D_bn1(self.D_conv1(High))) 388 | D2=self.D_relu2(self.D_bn2(self.D_conv2(D1))) 389 | D3=self.D_relu3(self.D_bn3(self.D_conv3(D2))) 390 | 391 | # High-level Integration 392 | high_depth = self.high_depth(D3) 393 | # CA 394 | Att_map_CA = self.pool_avg(self.conv_ca(high_depth)) 395 | Att_map_CA = torch.mul(F.softmax(Att_map_CA, dim=1), 64) 396 | Att_High = torch.mul(High, Att_map_CA) 397 | Enhance_High = Att_High + High 398 | # SA 399 | high_sal = self.high_sal(Enhance_High) 400 | Att_map_SA = F.softmax(high_sal,dim=1)[:,1:,:] 401 | Feature = torch.mul(Enhance_High, Att_map_SA) 402 | Enhance_Feature = Feature + Enhance_High 403 | # predict depth attention and sal_map attention 404 | D_1=self.D_relu1(self.D_bn1(self.D_conv1(Enhance_Feature))) 405 | D_2=self.D_relu2(self.D_bn2(self.D_conv2(D_1))) 406 | D_3=self.D_relu3(self.D_bn3(self.D_conv3(D_2))) 407 | pred_depth = self.pred2_depth(D_3) 408 | pred_sal2 = self.pred2_sal(Enhance_Feature) 409 | 
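# The remaining steps restore a common spatial resolution: Enhance_Feature is upsampled 4x and concatenated
# with Low to form the shared feature map, and the depth/saliency predictions (together with the attention
# maps high_depth and high_sal) are upsampled to match pred_edge1, which is already at the low-level
# resolution; demo.py unpacks these outputs and passes Features, Edge, Sal and Depth to the TriAtt collector.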
410 | Enhance_Feature = F.interpolate(Enhance_Feature, scale_factor=4, mode='bilinear', align_corners=False) 411 | Features = torch.cat([Enhance_Feature,Low],dim=1) 412 | 413 | high_depth = F.interpolate(high_depth, scale_factor=4, mode='bilinear', align_corners=False) 414 | high_sal = F.interpolate(high_sal, scale_factor=4, mode='bilinear', align_corners=False) 415 | pred_depth = F.interpolate(pred_depth, scale_factor=4, mode='bilinear', align_corners=False) 416 | pred_sal2 = F.interpolate(pred_sal2, scale_factor=4, mode='bilinear', align_corners=False) 417 | 418 | 419 | return Features, Features, Features, pred_edge1, high_depth, high_sal, pred_depth, pred_sal2 420 | 421 | 422 | 423 | -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/comm.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/comm.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/replicate.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/replicate.cpython-35.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/CoNet/modeling/sync_batchnorm/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 
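# Replica 0 acts as the master copy: it gathers the (sum, ssum, sum_size) messages from every replica through
# SyncMaster, computes the global mean and inverse std in _data_parallel_master(), and broadcasts them back;
# every other replica pushes its statistics through its SlavePipe and blocks until the reduced values arrive.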
65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 
141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 
202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 
85 | Args: 86 | identifier: an identifier, usually the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages are first collected from each device (including the master device), and then 101 | a callback is invoked to compute the message to be sent back to each device 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belong to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphic, we assign each sub-module a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback 35 | of any slave copies.
36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after the replicas are created by 51 | the original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /CoNet/modeling/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol, rtol=rtol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /CoNet/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import datetime 3 | 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | import torch 7 | 8 | 9 | 10 | running_loss_final = 0 11 | running_loss_pre = 0 12 | 13 | 14 | 15 | def cross_entropy2d(input, target, weight=None, size_average=True): 16 | n, c, h, w = input.size() 17 | 18 | input = input.transpose(1,2).transpose(2,3).contiguous() 19 | input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] 20 | input = input.view(-1, c) 21 | # target: (n*h*w,) 22 | mask = target >= 0 23 | target = target[mask] 24 | loss = F.cross_entropy(input, target, weight=weight, size_average=False) 25 | if size_average: 26 | loss /= mask.data.sum() 27 | return loss 28 | 29 | 30 | 31 | 32 | 33 | 34 | class Trainer(object): 35 | 36 | def __init__(self, cuda, model_rgb,model_intergration,model_att, optimizer_rgb, 37 | optimizer_inter,optimizer_att, 38 | train_loader, max_iter, snapshot, outpath, sshow, size_average=False): 39 | self.cuda = cuda 40 | self.model_rgb = model_rgb 41 | self.model_intergration = model_intergration 42 | self.model_att = model_att 43 | self.optim_rgb = optimizer_rgb 44 | self.optim_inter = optimizer_inter 45 | self.optim_att = optimizer_att 46 | self.train_loader = train_loader 47 | self.epoch = 0 48 | self.iteration = 0 49 | self.max_iter = max_iter 50 | self.snapshot = snapshot 51 | self.outpath = outpath 52 | self.sshow = sshow 53 | self.size_average = size_average 54 | 55 | 56 | 57 | def train_epoch(self): 58 | 59 | for batch_idx, (data, target, depth, edge) in enumerate(self.train_loader): 60 | 61 | iteration = batch_idx + self.epoch * len(self.train_loader) 62 | if self.iteration != 0 and (iteration - 1) != self.iteration: 63 | continue # for resuming 64 | self.iteration = iteration 65 | if self.iteration >= self.max_iter: 66 | break 67 | if self.cuda: 68 | data, target, depth, edge = data.cuda(), target.cuda(), depth.cuda(), edge.cuda() 69 | data, target, depth, edge = Variable(data), Variable(target), Variable(depth), Variable(edge) 70 | n, c, h, w = data.size() # batch_size, channels, height, width 71 | 72 | 73 | self.optim_rgb.zero_grad() 74 | self.optim_inter.zero_grad() 75 | self.optim_att.zero_grad() 76 | 77 | global running_loss_final 78 | global running_loss_pre 79 | 80 | low_1, low_2, high_1, high_2, high_3 = self.model_rgb(data) 81 | Features, _, _, pred_edge1, high_depth, high_sal, pred_depth, pred_sal2 = self.model_intergration(low_1, low_2, high_1, high_2, high_3) 82 | 83 | loss3 = cross_entropy2d(pred_edge1,edge,weight=None,size_average=self.size_average) 84 | loss4 = F.smooth_l1_loss(high_depth, depth, size_average=self.size_average) 85 | loss5 = cross_entropy2d(high_sal, target, weight=None, size_average=self.size_average) 86 | loss6 = F.smooth_l1_loss(pred_depth,
depth, size_average=self.size_average) 87 | loss7 = cross_entropy2d(pred_sal2, target, weight=None, size_average=self.size_average) 88 | loss_pre = ((loss3+loss5+loss7) + (loss4+loss6)*2.5)/5 89 | running_loss_pre += loss_pre.item() 90 | 91 | loss_pre.backward(retain_graph=True) 92 | self.optim_inter.step() 93 | self.optim_rgb.step() 94 | 95 | low_1, low_2, high_1, high_2, high_3 = self.model_rgb(data) 96 | Features, _, _, Edge, _, _, Depth, Sal = self.model_intergration(low_1, low_2, high_1, high_2, high_3) 97 | outputs = self.model_att(Features, Edge, Sal, Depth) 98 | loss_all = cross_entropy2d(outputs, target, weight=None, size_average=self.size_average) 99 | running_loss_final += loss_all.item() 100 | 101 | 102 | 103 | if iteration % self.sshow == (self.sshow-1): 104 | curr_time = str(datetime.datetime.now())[:19] 105 | print('\n [%s, %3d, %6d, The training loss of Net: %.3f, and the auxiliary loss: %.3f]' % (curr_time, self.epoch + 1, iteration + 1, running_loss_final / (n * self.sshow),running_loss_pre / (n * self.sshow))) 106 | 107 | running_loss_pre = 0.0 108 | running_loss_final = 0.0 109 | 110 | 111 | if iteration <= 200000: 112 | if iteration % self.snapshot == (self.snapshot-1): 113 | savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 114 | torch.save(self.model_rgb.state_dict(), savename) 115 | print('save: (snapshot: %d)' % (iteration+1)) 116 | 117 | savename_focal = ('%s/integrate_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 118 | torch.save(self.model_intergration.state_dict(), savename_focal) 119 | print('save: (snapshot_integrate: %d)' % (iteration+1)) 120 | 121 | savename_clstm = ('%s/att_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 122 | torch.save(self.model_att.state_dict(), savename_clstm) 123 | print('save: (snapshot_att: %d)' % (iteration+1)) 124 | 125 | else: 126 | if iteration % 20000 == (20000 - 1): 127 | savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 128 | torch.save(self.model_rgb.state_dict(), savename) 129 | print('save: (snapshot: %d)' % (iteration + 1)) 130 | 131 | savename_focal = ('%s/integrate_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 132 | torch.save(self.model_intergration.state_dict(), savename_focal) 133 | print('save: (snapshot_integrate: %d)' % (iteration + 1)) 134 | 135 | savename_clstm = ('%s/att_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 136 | torch.save(self.model_att.state_dict(), savename_clstm) 137 | print('save: (snapshot_att: %d)' % (iteration + 1)) 138 | 139 | 140 | 141 | if (iteration+1) == self.max_iter: 142 | savename = ('%s/snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 143 | torch.save(self.model_rgb.state_dict(), savename) 144 | print('save: (snapshot: %d)' % (iteration+1)) 145 | 146 | savename_focal = ('%s/integrate_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 147 | torch.save(self.model_intergration.state_dict(), savename_focal) 148 | print('save: (snapshot_integrate: %d)' % (iteration+1)) 149 | 150 | savename_clstm = ('%s/att_snapshot_iter_%d.pth' % (self.outpath, iteration+1)) 151 | torch.save(self.model_att.state_dict(), savename_clstm) 152 | print('save: (snapshot_att: %d)' % (iteration+1)) 153 | 154 | 155 | loss_all.backward() 156 | self.optim_att.step() 157 | self.optim_inter.step() 158 | self.optim_rgb.step() 159 | 160 | def train(self): 161 | max_epoch = int(math.ceil(1. 
* self.max_iter / len(self.train_loader))) 162 | 163 | for epoch in range(max_epoch): 164 | self.epoch = epoch 165 | self.train_epoch() 166 | if self.iteration >= self.max_iter: 167 | break 168 | -------------------------------------------------------------------------------- /CoNet/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from PIL import Image 5 | 6 | def colormap(n): # for each of the n indices, derive its r, g, b values bit by bit and build the colormap 7 | cmap=np.zeros([n, 3]).astype(np.uint8) 8 | 9 | for i in np.arange(n): 10 | r, g, b = np.zeros(3) 11 | 12 | for j in np.arange(8): 13 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) 14 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) 15 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) 16 | 17 | cmap[i,:] = np.array([r, g, b]) 18 | 19 | return cmap 20 | 21 | class Relabel: 22 | 23 | def __init__(self, olabel, nlabel): 24 | self.olabel = olabel 25 | self.nlabel = nlabel 26 | 27 | def __call__(self, tensor): 28 | assert isinstance(tensor, torch.LongTensor), 'tensor needs to be LongTensor' 29 | tensor[tensor == self.olabel] = self.nlabel 30 | return tensor 31 | 32 | 33 | class ToLabel: 34 | 35 | def __call__(self, image): 36 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) 37 | 38 | 39 | class Colorize: 40 | 41 | def __init__(self, n=21): 42 | self.cmap = colormap(256) 43 | self.cmap[n] = self.cmap[-1] 44 | self.cmap = torch.from_numpy(self.cmap[:n]) 45 | 46 | def __call__(self, gray_image): 47 | size = gray_image.size() 48 | color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) 49 | 50 | for label in range(1, len(self.cmap)): 51 | mask = (gray_image == label) 52 | 53 | color_image[0][mask] = self.cmap[label][0] 54 | color_image[1][mask] = self.cmap[label][1] 55 | color_image[2][mask] = self.cmap[label][2] 56 | 57 | return color_image 58 | -------------------------------------------------------------------------------- /Comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/Comparison.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoNet 2 | Code repository for our paper entitled ["Accurate RGB-D Salient Object Detection via Collaborative Learning"](https://arxiv.org/pdf/2007.11782.pdf) accepted at ECCV 2020 (poster). 3 | 4 | # Overall 5 | ![avatar](https://github.com/jiwei0921/CoNet/blob/master/overall111.png) 6 | 7 | 8 | ## CoNet Code 9 | 10 | ### > Requirement 11 | + pytorch 1.0.0+ 12 | + torchvision 13 | + PIL 14 | + numpy 15 | 16 | ### > Usage 17 | #### 1. Clone the repo 18 | ``` 19 | git clone https://github.com/jiwei0921/CoNet.git 20 | cd CoNet/ 21 | ``` 22 | 23 | #### 2. Train/Test 24 | + test 25 | Our test datasets are available at this [link](https://github.com/jiwei0921/RGBD-SOD-datasets) and the checkpoint at this [link](https://pan.baidu.com/s/1ceRpBrSjIxM0ut3t8awDfg) (extraction code: **12yn**). You need to set the dataset path and checkpoint name correctly. 26 | 27 | Set '--phase' to **test** in demo.py 28 | Set '--param' to **True** in demo.py 29 | ``` 30 | python demo.py 31 | ``` 32 | 33 | + train 34 | Our training dataset is available at this [link](https://pan.baidu.com/s/1EMKE7pwLg70sfYvQQAB1kA) (extraction code: **203g**). You need to set the dataset path and checkpoint name correctly.
35 | 36 | Set '--phase' to **train** in demo.py 37 | Set '--param' to **True or False** in demo.py 38 | Note: True loads a saved checkpoint before training, and False trains without loading a checkpoint. 39 | ``` 40 | python demo.py 41 | ``` 42 | 43 | ### > Results 44 | ![avatar](https://github.com/jiwei0921/CoNet/blob/master/Comparison.png) 45 | 46 | We provide [saliency maps](https://pan.baidu.com/s/1hQH89lhzgR3fk2Y3eI_Jww) (code: qrs2) of our CoNet on 8 datasets (DUT-RGBD, STEREO, NJUD, LFSD, RGBD135, NLPR, SSD, SIP) as well as 2 extended datasets (NJU2k and STERE1000), following CPFP_CVPR19. 47 | + Note: For evaluation, all results are computed with this ready-to-use [toolbox](https://github.com/jiwei0921/Saliency-Evaluation-Toolbox). 48 | 49 | 50 | ### > Related RGB-D Saliency Datasets 51 | All common RGB-D saliency datasets we collected are shared in a ready-to-use manner. 52 | + The web link is [here](https://github.com/jiwei0921/RGBD-SOD-datasets). 53 | 54 | 55 | ### If you think this work is helpful, please cite 56 | ``` 57 | @InProceedings{Wei_2020_ECCV, 58 | author={Ji, Wei and Li, Jingjing and Zhang, Miao and Piao, Yongri and Lu, Huchuan}, 59 | title = {Accurate {RGB-D} Salient Object Detection via Collaborative Learning}, 60 | booktitle = {European Conference on Computer Vision}, 61 | year = {2020} 62 | } 63 | ``` 64 | 65 | + For more info about CoNet, please read the [Manuscript](https://arxiv.org/pdf/2007.11782.pdf). 66 | + Thanks to the related authors for providing code or results, particularly [Deng-ping Fan](http://dpfan.net), [Hao Chen](https://github.com/haochen593), [Chun-biao Zhu](https://github.com/ChunbiaoZhu), etc. 67 | 68 | ### Contact Us 69 | More details can be found on [Wei Ji's GitHub](https://github.com/jiwei0921/). 70 | If you have any questions, please contact us ( weiji.dlut@gmail.com ). 71 | 72 | -------------------------------------------------------------------------------- /egbib.bib: -------------------------------------------------------------------------------- 1 | @InProceedings{Wei_2020_ECCV, 2 | author = {Wei {Ji} and Jingjing {Li} and Miao {Zhang} and Yongri {Piao} and Huchuan {Lu}}, 3 | title = {Accurate RGB-D Salient Object Detection via Collaborative Learning}, 4 | booktitle = "ECCV", 5 | year = {2020} 6 | } -------------------------------------------------------------------------------- /overall111.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiwei0921/CoNet/f9c1719f9c1a843f1eb04ac5271c4f2e061dfa90/overall111.png --------------------------------------------------------------------------------
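For orientation, here is a minimal sketch (not part of the repository) of how the synchronized BatchNorm layers and `DataParallelWithCallback` documented above might be wired together. The toy `ConvBlock` module, the import paths, and the device ids are illustrative assumptions rather than code from CoNet:
```
import torch
import torch.nn as nn

# Import paths assume the script is run from the CoNet/ directory.
from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from modeling.sync_batchnorm.replicate import DataParallelWithCallback


class ConvBlock(nn.Module):
    """Toy block that uses SynchronizedBatchNorm2d in place of nn.BatchNorm2d."""
    def __init__(self, in_ch=3, out_ch=16):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = SynchronizedBatchNorm2d(out_ch)   # statistics reduced across devices
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


if __name__ == '__main__':
    model = ConvBlock()
    x = torch.randn(4, 3, 64, 64)
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        # DataParallelWithCallback triggers __data_parallel_replicate__ on each
        # replica, which is what lets the BN layers exchange their statistics.
        model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])
        x = x.cuda()
    print(model(x).shape)   # torch.Size([4, 16, 64, 64])
```
On a single GPU or on the CPU the synchronized layer behaves like the built-in BatchNorm, as noted in the docstrings above, so the same block runs unchanged in both settings.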