├── LICENCE.md
├── README.md
├── demo.py
├── demo.sh
├── examples
│   ├── 000000000724.jpg
│   ├── 000000404922.jpg
│   ├── 00022_00197_outdoor_300_050.png
│   ├── 0SpJOOTH7R4_144577767_image.jpg
│   ├── 0SpJOOTH7R4_215215000_image.jpg
│   └── frame_0017.png
├── models
│   ├── DepthNet.py
│   ├── networks.py
│   ├── resnet.py
│   └── syncbn
│       ├── LICENSE
│       ├── README.md
│       ├── make_ext.sh
│       ├── modules
│       │   ├── __init__.py
│       │   ├── __init__.pyc
│       │   ├── __pycache__
│       │   │   └── __init__.cpython-37.pyc
│       │   ├── functional
│       │   │   ├── __init__.py
│       │   │   ├── __init__.pyc
│       │   │   ├── __pycache__
│       │   │   │   ├── __init__.cpython-37.pyc
│       │   │   │   └── syncbn.cpython-37.pyc
│       │   │   ├── _syncbn
│       │   │   │   ├── __init__.py
│       │   │   │   ├── __init__.pyc
│       │   │   │   ├── __pycache__
│       │   │   │   │   └── __init__.cpython-37.pyc
│       │   │   │   ├── _ext
│       │   │   │   │   ├── __init__.py
│       │   │   │   │   ├── __init__.pyc
│       │   │   │   │   ├── __pycache__
│       │   │   │   │   │   └── __init__.cpython-37.pyc
│       │   │   │   │   └── syncbn
│       │   │   │   │       ├── __init__.py
│       │   │   │   │       ├── __init__.pyc
│       │   │   │   │       ├── __pycache__
│       │   │   │   │       │   └── __init__.cpython-37.pyc
│       │   │   │   │       └── _syncbn.so
│       │   │   │   ├── build.py
│       │   │   │   └── src
│       │   │   │       ├── common.h
│       │   │   │       ├── syncbn.cpp
│       │   │   │       ├── syncbn.cu
│       │   │   │       ├── syncbn.cu.h
│       │   │   │       ├── syncbn.cu.o
│       │   │   │       └── syncbn.h
│       │   │   ├── syncbn.py
│       │   │   └── syncbn.pyc
│       │   └── nn
│       │       ├── __init__.py
│       │       ├── __init__.pyc
│       │       ├── __pycache__
│       │       │   ├── __init__.cpython-37.pyc
│       │       │   └── syncbn.cpython-37.pyc
│       │       ├── syncbn.py
│       │       └── syncbn.pyc
│       ├── requirements.txt
│       └── test.py
└── ranking_loss.py
/LICENCE.md:
--------------------------------------------------------------------------------
1 |
2 | This software is for non-commercial purposes
3 |
4 | Copyright (c) 2020 Ke Xian All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7 |
8 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9 |
10 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Structure-Guided Ranking Loss for Single Image Depth Prediction
2 | This repository contains a PyTorch implementation of our CVPR2020 paper "Structure-Guided Ranking Loss for Single Image Depth Prediction".
3 | [Project Page](https://KexianHust.github.io/Structure-Guided-Ranking-Loss/)
4 | 
5 |
6 | ## Changelog
7 | * [Jun. 2020] Initial release
8 |
9 | ## To do
10 | - [ ] Mix data training
11 |
12 | ## Prerequisites
13 | * Pytorch >= 0.4.1
14 | * CUDA >= 8.0
15 | * Python >= 2.7
16 | * glob, matplotlib
17 | * Compile the syncbn module in models/syncbn (see models/syncbn/README.md). Note that DepthNet.py, resnet.py and networks.py each append a hard-coded, absolute path to the syncbn module to `sys.path`; change it to the location of models/syncbn on your machine.
18 | * Download the [model.pth.tar](https://drive.google.com/file/d/1p8c8-nUTNry5usQmGdTC2TrwWrp3dQ0y/view?usp=sharing)
19 |
20 | ## Inference
21 | ```bash
22 | # Before running, you should set the CUDA_VISIBLE_DEVICES in demo.sh
23 | bash demo.sh
24 |
25 | ```
26 |
27 | If you find our work useful in your research, please consider citing the paper.
28 |
29 | ```
30 | @InProceedings{Xian_2020_CVPR,
31 | author = {Xian, Ke and Zhang, Jianming and Wang, Oliver and Mai, Long and Lin, Zhe and Cao, Zhiguo},
32 | title = {Structure-Guided Ranking Loss for Single Image Depth Prediction},
33 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
34 | month = {June},
35 | year = {2020}
36 | }
37 | ```
38 |
39 | ## Dataset
40 | Our [HRWSI](https://drive.google.com/file/d/1OVOx6x-B0Cs-m2z_-7ZxSgRFHz_VBvDd/view?usp=sharing) dataset is for research only! Some researchers may be interested in the stereo data, so we provide the right views [here](https://drive.google.com/file/d/1HzEB7yQI05Q21dP9rRjnyMoEmvCckAQp/view?usp=sharing). Please let me know if you have any questions.
41 |
42 | ## License
43 | Research only
44 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # demo
5 |
6 | """
7 | Author: Ke Xian
8 | Email: kexian@hust.edu.cn
9 | Create_Date: 2019/05/21
10 | """
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torchvision.transforms as transforms
15 | from torch.utils.data import DataLoader
16 | torch.backends.cudnn.deterministic = True
17 | torch.manual_seed(123)
18 |
19 | import os, argparse, sys
20 | import numpy as np
21 | import glob
22 | import matplotlib.pyplot as plt
23 | plt.switch_backend('agg')
24 | import warnings
25 | warnings.filterwarnings("ignore")
26 | from PIL import Image
27 |
28 | sys.path.append('models')
29 | import DepthNet
30 |
31 | # =======================
32 | # demo
33 | # =======================
def demo(net, args):
    """Predict a depth map for every image in ``args.data_dir`` and save it.

    Each image is resized to ``args.img_size`` (width, height), normalized with
    the ImageNet mean/std, run through ``net``, min-max normalized to [0, 255],
    resized back to its original resolution and written to ``args.result_dir``
    under the same file name using the 'inferno' colormap.

    Args:
        net: depth network, called with a (1, 3, height, width) tensor.
        args: namespace providing ``data_dir``, ``img_size`` and ``result_dir``.
    """
    data_dir = args.data_dir
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    int_width, int_height = args.img_size[0], args.img_size[1]

    # Feed inputs to the device that holds the weights (CUDA when launched via
    # demo.sh); fall back to CUDA, the previously hard-coded device.
    try:
        device = next(net.parameters()).device
    except StopIteration:
        device = torch.device('cuda')

    # sorted() makes the processing order deterministic across platforms.
    for im in sorted(os.listdir(data_dir)):
        im_dir = os.path.join(data_dir, im)
        if not os.path.isfile(im_dir):
            continue  # skip sub-directories
        print('Processing img: {}'.format(im_dir))

        # Read image; skip files PIL cannot decode instead of aborting the run.
        try:
            img = Image.open(im_dir).convert('RGB')
        except (IOError, OSError):
            print('Skipping unreadable file: {}'.format(im_dir))
            continue
        ori_width, ori_height = img.size
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        img = img.resize((int_width, int_height), Image.LANCZOS)
        tensor_img = img_transform(img)

        # forward. Variable(volatile=True) has had no effect since PyTorch 0.4;
        # no_grad() is the supported way to disable autograd here.
        input_img = tensor_img.unsqueeze(0).to(device)
        with torch.no_grad():
            output = net(input_img)

        # Normalization and save results
        depth = output.squeeze().detach().cpu().numpy()
        min_d, max_d = depth.min(), depth.max()
        depth_range = max_d - min_d
        if depth_range > 0:
            depth_norm = (depth - min_d) / depth_range * 255
        else:
            # Constant prediction: avoid 0/0 -> NaN, which casts to garbage uint8.
            depth_norm = np.zeros_like(depth)
        depth_norm = depth_norm.astype(np.uint8)
        image_pil = Image.fromarray(depth_norm)

        output_dir = os.path.join(args.result_dir, im)
        image_pil = image_pil.resize((ori_width, ori_height), Image.BILINEAR)
        plt.imsave(output_dir, np.asarray(image_pil), cmap='inferno')
67 |
68 |
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='MRDP Testing/Evaluation')
    # type=int with nargs=2 parses "--img_size 448 448" into [448, 448]. The
    # former type=list split a command-line value character by character
    # ("448" -> ['4', '4', '8']), so only the default ever worked.
    parser.add_argument('--img_size', default=[448, 448], type=int, nargs=2,
                        metavar=('WIDTH', 'HEIGHT'),
                        help='Image size of network input')
    parser.add_argument('--data_dir', default='examples', type=str, help='Data path')
    parser.add_argument('--result_dir', default='demo_results', type=str,
                        help='Directory for saving results, default: demo_results')
    parser.add_argument('--gpu_id', default=0, type=int, help='GPU id, default:0')
    # Previously hard-coded; the default keeps the old behaviour.
    parser.add_argument('--checkpoint', default='model.pth.tar', type=str,
                        help='Path to the trained weights, default: model.pth.tar')
    args = parser.parse_args()

    if not os.path.exists(args.result_dir):
        os.makedirs(args.result_dir)

    gpu_id = args.gpu_id
    # torch.cuda.device(...) only returns a context manager, so calling it on
    # its own selected nothing; set_device makes gpu_id the current device.
    torch.cuda.set_device(gpu_id)

    # The checkpoint below overwrites every weight (load_state_dict is strict),
    # so downloading the ImageNet-pretrained encoder first is wasted work.
    net = DepthNet.DepthNet(pretrained=False)
    # device_ids must contain the device .cuda() moves the weights to (the
    # current device), otherwise DataParallel fails for gpu_id != 0.
    net = torch.nn.DataParallel(net, device_ids=[gpu_id]).cuda()
    # map_location='cpu' lets a checkpoint saved on any GPU load here;
    # load_state_dict then copies the tensors onto the model's device.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    print('Begin to test ...')
    with torch.no_grad():
        demo(net, args)
    print('Finished!')
96 |
--------------------------------------------------------------------------------
/demo.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Run the demo on GPU 0 unless CUDA_VISIBLE_DEVICES is already set in the
# environment; any extra arguments are forwarded to demo.py
# (e.g. bash demo.sh --data_dir my_images --result_dir my_results).
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} python demo.py "$@"
2 |
--------------------------------------------------------------------------------
/examples/000000000724.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/000000000724.jpg
--------------------------------------------------------------------------------
/examples/000000404922.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/000000404922.jpg
--------------------------------------------------------------------------------
/examples/00022_00197_outdoor_300_050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/00022_00197_outdoor_300_050.png
--------------------------------------------------------------------------------
/examples/0SpJOOTH7R4_144577767_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/0SpJOOTH7R4_144577767_image.jpg
--------------------------------------------------------------------------------
/examples/0SpJOOTH7R4_215215000_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/0SpJOOTH7R4_215215000_image.jpg
--------------------------------------------------------------------------------
/examples/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/frame_0017.png
--------------------------------------------------------------------------------
/models/DepthNet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # coding: utf-8
3 |
4 | '''
5 | Author: Ke Xian
6 | Email: kexian@hust.edu.cn
7 | Date: 2019/04/09
8 | '''
9 |
10 | import torch
11 | import torchvision
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.nn.init as init
15 |
16 | import sys
17 | sys.path.append('/data0/kexian/Code/kxian_Adobe/MPO_edgeGuidedRanking/models/syncbn')
18 | from modules import nn as NN
19 |
20 | import resnet
21 |
22 | from networks import *
23 |
class Decoder(nn.Module):
    """Top-down decoder that fuses four encoder feature maps into a depth map.

    ``forward`` takes a sequence of four feature maps ordered from the highest
    resolution (``features[0]``) to the lowest (``features[3]``). The deepest map
    goes through an FTB, a 3x3 conv and an upsampling step. It is then fused
    with ``features[2]``, ``features[1]`` and ``features[0]`` by three FFM
    stages, and finally mapped to ``outchannels`` channels by the AO head.

    Args:
        inchannels: channels of features[0..3]; default [256, 512, 1024, 2048].
        midchannels: decoder widths per stage; default [256, 256, 256, 512].
            Each FFM adds its lateral branch (``midchannels[i]`` channels) to the
            previous stage's output (``midchannels[i + 1]`` channels), so
            midchannels[0], [1] and [2] must be equal.
        upfactors: upsampling factors per stage; default [2, 2, 2, 2].
        outchannels: channels of the predicted map.
    """
    def __init__(self, inchannels=None, midchannels=None, upfactors=None, outchannels=1):
        super(Decoder, self).__init__()
        # None sentinels instead of list defaults so instances never share one
        # mutable default list.
        self.inchannels = inchannels if inchannels is not None else [256, 512, 1024, 2048]
        self.midchannels = midchannels if midchannels is not None else [256, 256, 256, 512]
        self.upfactors = upfactors if upfactors is not None else [2, 2, 2, 2]
        self.outchannels = outchannels

        self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3])
        self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True)
        self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True)

        self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels=self.midchannels[2], upfactor=self.upfactors[2])
        self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels=self.midchannels[1], upfactor=self.upfactors[1])
        self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels=self.midchannels[0], upfactor=self.upfactors[0])

        # ffm0 outputs midchannels[0] channels, so that is the AO input width.
        # The original passed inchannels[0], which only worked because both
        # defaults are 256 (parameter shapes are unchanged for the defaults).
        self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2)

        self._init_params()

    def _init_params(self):
        """Re-initialize all submodules: N(0, 0.01) weights and zero biases for
        Conv2d/ConvTranspose2d/Linear, unit weight and zero bias for NN.BatchNorm2d."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, NN.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, features):
        """Decode ``features`` (highest to lowest resolution) into the output map."""
        x = self.conv(features[3])
        x = self.conv1(x)
        x = self.upsample(x)

        x = self.ffm2(features[2], x)
        x = self.ffm1(features[1], x)
        x = self.ffm0(features[0], x)

        return self.outconv(x)
80 |
class DepthNet(nn.Module):
    """Depth-prediction network: ResNet encoder followed by ``Decoder``.

    Args:
        backbone: stored on the instance only; the encoder is always a ResNet
            selected by ``depth``.
        depth: ResNet depth, one of 18, 34, 50, 101, 152.
        pretrained: initialize the encoder from torchvision ImageNet weights.
        inchannels: channels of the four encoder feature maps. The default
            [256, 512, 1024, 2048] matches the Bottleneck ResNets (50/101/152).
            ResNet-18/34 produce [64, 128, 256, 512], which must then be passed
            explicitly.
        midchannels: decoder widths; default [256, 256, 256, 512].
        upfactors: decoder upsampling factors; default [2, 2, 2, 2].
        outchannels: channels of the predicted map.

    Raises:
        KeyError: if ``depth`` is not a supported ResNet depth.
    """
    __factory = {
        18: resnet.resnet18,
        34: resnet.resnet34,
        50: resnet.resnet50,
        101: resnet.resnet101,
        152: resnet.resnet152
    }

    def __init__(self,
                 backbone='resnet',
                 depth=50,
                 pretrained=True,
                 inchannels=None,
                 midchannels=None,
                 upfactors=None,
                 outchannels=1):
        super(DepthNet, self).__init__()
        self.backbone = backbone
        self.depth = depth
        self.pretrained = pretrained
        # None sentinels avoid sharing mutable default lists between instances.
        self.inchannels = inchannels if inchannels is not None else [256, 512, 1024, 2048]
        self.midchannels = midchannels if midchannels is not None else [256, 256, 256, 512]
        self.upfactors = upfactors if upfactors is not None else [2, 2, 2, 2]
        self.outchannels = outchannels

        # Build model
        if self.depth not in DepthNet.__factory:
            raise KeyError("Unsupported depth: {} (expected one of {})".format(
                self.depth, sorted(DepthNet.__factory)))
        self.encoder = DepthNet.__factory[depth](pretrained=pretrained)

        self.decoder = Decoder(inchannels=self.inchannels, midchannels=self.midchannels, upfactors=self.upfactors, outchannels=self.outchannels)

    def forward(self, x):
        """Encode ``x`` into four feature maps (strides 4, 8, 16, 32) and decode them."""
        features = self.encoder(x)
        return self.decoder(features)
118 |
if __name__ == '__main__':
    # Smoke test: build the ResNet-50 variant, print its structure and push a
    # dummy batch of four 128x128 RGB images through it.
    model = DepthNet(depth=50, pretrained=True)
    print(model)
    dummy_batch = torch.ones(4, 3, 128, 128)
    prediction = model(dummy_batch)
    print(prediction.size())
125 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # -*- coding: utf-8 -*-
3 | '''
4 | Author: Ke Xian
5 | Email: kexian@hust.edu.cn
6 | Date: 2019/04/09
7 | '''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.init as init
12 | import sys
13 | sys.path.append('/data0/kexian/Code/kxian_Adobe/MPO_edgeGuidedRanking/models/syncbn')
14 | import modules.nn as NN
15 |
16 | # ==============================================================================================================
17 |
class FTB(nn.Module):
    """Feature Transfer Block: 3x3 conv followed by a residual conv branch.

    Maps ``inchannels`` to ``midchannels`` channels; all convs use 3x3 kernels
    with padding 1 and stride 1, so the spatial size is unchanged.
    """
    def __init__(self, inchannels, midchannels=512):
        super(FTB, self).__init__()
        self.in1 = inchannels
        self.mid = midchannels

        self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
        branch_layers = [
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),
            NN.BatchNorm2d(num_features=self.mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),
        ]
        self.conv_branch = nn.Sequential(*branch_layers)
        self.relu = nn.ReLU(inplace=True)

        self.init_params()

    def forward(self, x):
        """Return relu(identity + conv_branch(identity)) with identity = conv1(x)."""
        identity = self.conv1(x)
        # NOTE: conv_branch begins with an in-place ReLU, which rewrites
        # `identity` before the addition is evaluated, so the sum is effectively
        # relu(conv1(x)) + branch(relu(conv1(x))). Kept as-is: changing it would
        # alter the outputs of existing checkpoints.
        summed = identity + self.conv_branch(identity)
        return self.relu(summed)

    def init_params(self):
        """Re-initialize all submodules: N(0, 0.01) weights and zero biases for
        Conv2d/ConvTranspose2d/Linear, unit weight and zero bias for NN.BatchNorm2d."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, NN.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
63 |
64 |
class FFM(nn.Module):
    """Feature Fusion Module: merge a lateral encoder map into the top-down path.

    ``ftb1`` maps the lateral input ``low_x`` to ``midchannels`` channels. The
    result is added to ``high_x``, which therefore must match it in channels
    and spatial size. The sum is refined by ``ftb2`` to ``outchannels``
    channels and bilinearly upsampled by ``upfactor``.
    """
    def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
        super(FFM, self).__init__()
        self.inchannels = inchannels
        self.midchannels = midchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
        self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
        self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)

        self.init_params()

    def forward(self, low_x, high_x):
        """Fuse ``low_x`` (lateral) with ``high_x`` (top-down) and upsample."""
        fused = self.ftb1(low_x) + high_x
        return self.upsample(self.ftb2(fused))

    def init_params(self):
        """Re-initialize all submodules: N(0, 0.01) weights and zero biases for
        Conv2d/ConvTranspose2d/Linear, unit weight and zero bias for NN.BatchNorm2d."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, NN.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
109 |
110 |
class AO(nn.Module):
    """Adaptive output module producing the final prediction map.

    Pipeline: 3x3 conv (inchannels -> inchannels // 2), NN.BatchNorm2d, ReLU,
    3x3 conv (inchannels // 2 -> outchannels), then bilinear upsampling by
    ``upfactor``.
    """
    def __init__(self, inchannels, outchannels, upfactor=2):
        super(AO, self).__init__()
        self.inchannels = inchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        # Floor division: under Python 3 `/` yields a float, which nn.Conv2d and
        # BatchNorm2d reject as a channel count. `//` matches the Python 2
        # integer-division behaviour the code was written for.
        mid = self.inchannels // 2
        self.adapt_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.inchannels, out_channels=mid, kernel_size=3, padding=1, stride=1, bias=True),
            NN.BatchNorm2d(num_features=mid),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=mid, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),
            nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True))

        self.init_params()

    def forward(self, x):
        """Map ``x`` to ``outchannels`` channels at ``upfactor`` times its size."""
        return self.adapt_conv(x)

    def init_params(self):
        """Re-initialize all submodules: N(0, 0.01) weights and zero biases for
        Conv2d/ConvTranspose2d/Linear, unit weight and zero bias for NN.BatchNorm2d."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, NN.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
152 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | import torchvision
5 |
6 | import sys
7 | sys.path.append('/data0/kexian/Code/kxian_Adobe/MPO_edgeGuidedRanking/models/syncbn')
8 | from modules import nn as NN
9 |
10 |
11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
12 | 'resnet152']
13 |
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | }
22 |
23 |
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    With padding 1 the spatial size is preserved for stride 1 and divided
    (rounding up) by ``stride`` otherwise.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
28 |
29 |
class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block used by ResNet-18/34 (expansion 1).

    ``stride`` is applied in the first convolution. ``downsample``, when given,
    maps the input onto the residual branch's shape before the addition.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = NN.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = NN.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
60 |
61 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block used by ResNet-50/101/152.

    The output has ``planes * expansion`` (= 4 * planes) channels and the
    stride is applied in the 3x3 convolution. ``downsample``, when given, maps
    the input onto the residual branch's shape before the addition.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        expanded = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = NN.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = NN.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = NN.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(bn3(conv3(...)) + shortcut(x))."""
        branch = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            branch = self.relu(bn(conv(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
99 |
100 |
class ResNet(nn.Module):
    """ResNet trunk without pooling/classifier head, returning multi-scale features.

    ``forward`` returns a list with the outputs of ``layer1``..``layer4``, at
    strides 4, 8, 16 and 32 relative to the input, with
    ``64/128/256/512 * block.expansion`` channels respectively.

    Args:
        block: residual block class with an ``expansion`` attribute and an
            ``(inplanes, planes, stride, downsample)`` constructor.
        layers: number of blocks in each of the four stages.
        num_classes: unused here (the classifier head is not built); kept for
            signature compatibility.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = NN.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            # Every BN layer here is the synchronized NN.BatchNorm2d. Checking
            # only nn.BatchNorm2d (as before) may not match it (TODO confirm
            # whether NN.BatchNorm2d subclasses nn.BatchNorm2d), which would
            # leave those weights at the library default. Check both explicitly,
            # as the other modules in this repo do.
            elif isinstance(m, (nn.BatchNorm2d, NN.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # 1x1 projection so the shortcut matches the block output shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                NN.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``[layer1_out, layer2_out, layer3_out, layer4_out]``."""
        features = []

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)

        return features
160 |
161 |
def _load_torchvision_weights(model, arch):
    """Copy ImageNet weights from ``torchvision.models.<arch>`` into ``model``.

    Only entries whose key also exists in ``model``'s state dict are copied
    (e.g. torchvision's classifier ``fc`` is dropped); every other parameter
    keeps its current initialization. Returns ``model``.
    """
    pretrained_model = getattr(torchvision.models, arch)(pretrained=True)
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items()
                       if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model


def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        _load_torchvision_weights(model, 'resnet18')
    return model


def resnet34(pretrained=True, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        _load_torchvision_weights(model, 'resnet34')
    return model


def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        _load_torchvision_weights(model, 'resnet50')
    return model


def resnet101(pretrained=True, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        _load_torchvision_weights(model, 'resnet101')
    return model


def resnet152(pretrained=True, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        _load_torchvision_weights(model, 'resnet152')
    return model
245 |
--------------------------------------------------------------------------------
/models/syncbn/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tamaki Kojima
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/syncbn/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-syncbn
2 |
3 | Tamaki Kojima(tamakoji@gmail.com)
4 |
5 | ## Overview
6 | This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global statistics across GPUs instead of per-GPU local statistics. SyncBN becomes important when input images are large and multiple GPUs must be used to increase the mini-batch size during training.
7 |
8 | The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn)
9 |
10 | ## Remarks
11 | - Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need custom `nn.DataParallel`
12 | - Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module implementation, since it does not mark tensors for in-place operation
13 | - You can plug into arbitrary module written in PyTorch to enable Synchronized BatchNorm
14 | - Backward computation is rewritten and tested against behavior of `nn.BatchNorm2d`
15 |
16 | ## Requirements
17 | For PyTorch, please refer to https://pytorch.org/
18 |
19 | NOTE : The code is tested only with PyTorch v0.4.0, CUDA9.1.85/CuDNN7.1.4 on ubuntu16.04
20 |
21 | (It can also be compiled and run on the Jetson TX2, but it won't work as multi-GPU synchronized BN there.)
22 |
23 | To install all dependencies using pip, run:
24 |
25 | ```
26 | pip install -U -r requirements.txt
27 | ```
28 |
29 | ## Build
30 |
31 | use `make_ext.sh` to build the extension. for example:
32 | ```
33 | PYTHON_CMD=python3 ./make_ext.sh
34 | ```
35 |
36 | ## Usage
37 |
38 | Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`
39 |
40 | ```
41 | import torch
   | import torch.nn as nn
42 | from modules import nn as NN
43 | num_gpu = torch.cuda.device_count()
44 | model = nn.Sequential(
45 | nn.Conv2d(3, 3, 1, 1, bias=False),
46 | NN.BatchNorm2d(3),
47 | nn.ReLU(inplace=True),
48 | nn.Conv2d(3, 3, 1, 1, bias=False),
49 | NN.BatchNorm2d(3),
50 | ).cuda()
51 | model = nn.DataParallel(model, device_ids=range(num_gpu))
52 | x = torch.rand(num_gpu, 3, 2, 2).cuda()
53 | z = model(x)
54 | ```
55 |
56 | ## Math
57 |
58 | ### Forward
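59 | 1. compute $\sum_i x_i$ and $\sum_i x_i^2$ on each GPU
60 | 2. gather all $\sum_i x_i$ and $\sum_i x_i^2$ from the workers to the master and compute $\mu$ and $\sigma^2$, where
61 |
62 |     $$\mu = \frac{\sum_i x_i}{N}$$
63 |
64 |     and
65 |
66 |     $$\sigma^2 = \frac{\sum_i x_i^2 - \mu \sum_i x_i}{N}$$
67 |
68 |     The global statistics above are then shared with all GPUs, and running_mean and running_var are updated by a moving average of these global statistics.
69 |
70 | 3. run the batch-norm forward pass with the global statistics:
71 |
72 |     $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
73 |
74 |     and then
75 |
76 |     $$y_i = \gamma \cdot \hat{x}_i + \beta$$
77 |
78 |     where $\gamma$ is the weight parameter and $\beta$ is the bias parameter.
79 |
80 | 4. save $x, \gamma, \beta, \mu, \sigma^2$ for the backward pass
81 |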
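82 | ### Backward
83 |
84 | 1. Restore the saved $x$, $\gamma$, $\beta$, $\mu$ and $\sigma^2$.
85 |
86 | 2. Compute the following sums on each GPU:
87 |
88 | $$\sum_{i} \frac{\partial J}{\partial y_i}$$
89 |
90 | and
91 |
92 | $$\sum_{i} \frac{\partial J}{\partial y_i}\hat{x}_i$$
93 |
94 | where $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$.
95 |
96 | Then gather them on the master node, sum them into global sums, and normalize by $N$, the total number of elements per channel. The global sums are then shared among all GPUs.
97 |
98 | 3. Compute the gradients using the global statistics:
99 |
100 | $$\frac{\partial J}{\partial x_i} = \frac{1}{N\sqrt{\sigma^2+\epsilon}}\left(N\frac{\partial J}{\partial \hat{x}_i} - \sum_{j=1}^{N}\frac{\partial J}{\partial \hat{x}_j} - \hat{x}_i\sum_{j=1}^{N}\frac{\partial J}{\partial \hat{x}_j}\hat{x}_j\right)$$
101 |
102 | where
103 |
104 | $$\frac{\partial J}{\partial \hat{x}_i} = \gamma\frac{\partial J}{\partial y_i}$$
105 |
106 | The parameter gradients are
107 |
108 | $$\frac{\partial J}{\partial \gamma} = \sum_{i=1}^{N}\frac{\partial J}{\partial y_i}\hat{x}_i$$
109 |
110 | and
111 |
112 | $$\frac{\partial J}{\partial \beta} = \sum_{i=1}^{N}\frac{\partial J}{\partial y_i}$$
113 |
114 |
115 |
116 |
117 |
118 | Note that in the implementation, the normalization by $N$ is performed in step (2), so the equations above and the implementation are not literally identical, but they are mathematically equivalent.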
119 |
120 | A more detailed derivation is available in [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/)
--------------------------------------------------------------------------------
/models/syncbn/make_ext.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Build the BatchNorm2dSync CUDA extension.
#
# Usage: [PYTHON_CMD=python3] [CUDA_PATH=/usr/local/cuda] ./make_ext.sh [TORCH_DIR]
#
#   1. nvcc compiles modules/functional/_syncbn/src/syncbn.cu into syncbn.cu.o
#   2. build.py links that object into the `_ext.syncbn` ffi extension
#
# Abort on the first failing command (so a failed nvcc/build.py is no longer
# followed by "Build Complete") and on use of unset variables.
set -eu

PYTHON_CMD=${PYTHON_CMD:=python}
# CUDA toolkit location; override through the environment for other versions.
CUDA_PATH=${CUDA_PATH:=/usr/local/cuda-8.0}
CUDA_INCLUDE_DIR=${CUDA_INCLUDE_DIR:=$CUDA_PATH/include}
GENCODE="-gencode arch=compute_61,code=sm_61 \
-gencode arch=compute_52,code=sm_52 \
-gencode arch=compute_52,code=compute_52"
NVCCOPT="-std=c++11 -x cu --expt-extended-lambda -O3 -Xcompiler -fPIC"

# Resolve paths relative to this script so it can be run from any directory.
ROOTDIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
echo "========= Build BatchNorm2dSync ========="
# TORCH (torch install dir, or $1) is not used below; resolving it still checks
# up front that `import torch` works with $PYTHON_CMD.
if [ -z "${1:-}" ]; then TORCH=$($PYTHON_CMD -c "import os; import torch; print(os.path.dirname(torch.__file__))"); else TORCH="$1"; fi
cd "$ROOTDIR/modules/functional/_syncbn/src"
# $NVCCOPT and $GENCODE are intentionally unquoted: they hold multiple flags.
"$CUDA_PATH/bin/nvcc" -c -o syncbn.cu.o syncbn.cu $NVCCOPT $GENCODE -I "$CUDA_INCLUDE_DIR"
cd ..
$PYTHON_CMD build.py
cd "$ROOTDIR"

# END
echo "========= Build Complete ========="
22 |
--------------------------------------------------------------------------------
/models/syncbn/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/__init__.py
--------------------------------------------------------------------------------
/models/syncbn/modules/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/__init__.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/__init__.py:
--------------------------------------------------------------------------------
1 | from .syncbn import batchnorm2d_sync
2 |
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/__init__.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/__pycache__/syncbn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/__pycache__/syncbn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/__init__.py
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/__init__.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/_ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/__init__.py
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/_ext/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/__init__.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/_ext/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.ffi import _wrap_function
3 | from ._syncbn import lib as _lib, ffi as _ffi
4 |
# NOTE(review): this module looks generated by build.py
# (torch.utils.ffi.create_extension); manual edits may be overwritten on rebuild.
__all__ = []


def _import_symbols(locals):
    """Copy every name listed by ``dir(_lib)`` into the ``locals`` mapping.

    Callable attributes are stored wrapped as ``_wrap_function(attr, _ffi)``;
    all other attributes are stored unchanged. Each copied name is also
    appended to the module's ``__all__``.
    """
    for name in dir(_lib):
        attr = getattr(_lib, name)
        locals[name] = _wrap_function(attr, _ffi) if callable(attr) else attr
        __all__.append(name)


_import_symbols(locals())
16 |
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/_ext/syncbn/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/syncbn/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/_ext/syncbn/_syncbn.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/syncbn/_syncbn.so
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/build.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.ffi import create_extension
3 |
# Build script for the `_ext.syncbn` cffi extension; make_ext.sh runs it after
# nvcc has compiled src/syncbn.cu into src/syncbn.cu.o.
# NOTE(review): torch.utils.ffi.create_extension only exists in old PyTorch
# releases (the README targets PyTorch v0.4.0).

# Directory that contains this script (despite the historical name).
this_file = os.path.dirname(os.path.realpath(__file__))

with_cuda = True
headers = ['src/syncbn.h']    # C declarations exposed through cffi
sources = ['src/syncbn.cpp']  # THC wrapper functions
# Prebuilt CUDA object produced by make_ext.sh, referenced by absolute path.
extra_objects = [
    os.path.join(this_file, fname) for fname in ('src/syncbn.cu.o',)
]

ffi = create_extension(
    '_ext.syncbn',
    headers=headers,
    sources=sources,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects,
    extra_compile_args=["-std=c++11"],
)

if __name__ == '__main__':
    ffi.build()
24 |
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/src/common.h:
--------------------------------------------------------------------------------
#ifndef SYNCBN_COMMON_H
#define SYNCBN_COMMON_H
// Provides CUDART_VERSION, which WARP_SHFL_XOR tests below.
#include <cuda_runtime_api.h>

/*
 * General settings
 */
// Lanes per warp assumed by the reductions.
const int WARP_SIZE = 32;
// Upper bound on the thread-block size chosen by getNumThreads.
const int MAX_BLOCK_SIZE = 512;

/*
 * Utility functions
 */

// Exchange `value` with the lane whose index is (own lane XOR laneMask).
// CUDA >= 9 requires the *_sync variant with an explicit participation mask;
// older toolkits use the unsynchronized __shfl_xor.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(
    T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) {
#if CUDART_VERSION >= 9000
  return __shfl_xor_sync(mask, value, laneMask, width);
#else
  return __shfl_xor(value, laneMask, width);
#endif
}

// Index of the most significant set bit of `val` (e.g. getMSB(32) == 5).
// For val == 0, __clz returns 32 and the result is -1.
__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }

// Smallest block size in {32, 64, 128, 256, MAX_BLOCK_SIZE} that is >= nElem;
// larger inputs are capped at MAX_BLOCK_SIZE (threads then loop over nElem).
static int getNumThreads(int nElem) {
  int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
  for (int i = 0; i != 5; ++i) {
    if (nElem <= threadSizes[i]) {
      return threadSizes[i];
    }
  }
  return MAX_BLOCK_SIZE;
}


#endif  // SYNCBN_COMMON_H
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/src/syncbn.cpp:
--------------------------------------------------------------------------------
1 | // All functions assume that input and output tensors are already initialized
2 | // and have the correct dimensions
3 | #include
4 |
5 | extern THCState *state;
6 |
7 | void get_sizes(const THCudaTensor *t, int *N, int *C, int *S) {
8 | // Get sizes
9 | *S = 1;
10 | *N = THCudaTensor_size(state, t, 0);
11 | *C = THCudaTensor_size(state, t, 1);
12 | if (THCudaTensor_nDimension(state, t) > 2) {
13 | for (int i = 2; i < THCudaTensor_nDimension(state, t); ++i) {
14 | *S *= THCudaTensor_size(state, t, i);
15 | }
16 | }
17 | }
18 |
19 | // Forward definition of implementation functions
20 | extern "C" {
21 | int _syncbn_sum_sqsum_cuda(int N, int C, int S,
22 | const float *x, float *sum, float *sqsum,
23 | cudaStream_t stream);
24 | int _syncbn_forward_cuda(
25 | int N, int C, int S, float *z, const float *x,
26 | const float *gamma, const float *beta,
27 | const float *mean, const float *var, float eps, cudaStream_t stream);
28 | int _syncbn_backward_xhat_cuda(
29 | int N, int C, int S, const float *dz, const float *x,
30 | const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat,
31 | float eps, cudaStream_t stream);
32 | int _syncbn_backward_cuda(
33 | int N, int C, int S, const float *dz, const float *x,
34 | const float *gamma, const float *beta,
35 | const float *mean, const float *var,
36 | const float *sum_dz, const float *sum_dz_xhat,
37 | float *dx, float *dgamma, float *dbeta,
38 | float eps, cudaStream_t stream);
39 | }
40 |
41 | extern "C" int syncbn_sum_sqsum_cuda(
42 | const THCudaTensor *x, THCudaTensor *sum, THCudaTensor *sqsum) {
43 | cudaStream_t stream = THCState_getCurrentStream(state);
44 |
45 | int S, N, C;
46 | get_sizes(x, &N, &C, &S);
47 |
48 | // Get pointers
49 | const float *x_data = THCudaTensor_data(state, x);
50 | float *sum_data = THCudaTensor_data(state, sum);
51 | float *sqsum_data = THCudaTensor_data(state, sqsum);
52 |
53 | return _syncbn_sum_sqsum_cuda(N, C, S, x_data, sum_data, sqsum_data, stream);
54 | }
55 |
56 | extern "C" int syncbn_forward_cuda(
57 | THCudaTensor *z, const THCudaTensor *x,
58 | const THCudaTensor *gamma, const THCudaTensor *beta,
59 | const THCudaTensor *mean, const THCudaTensor *var, float eps){
60 | cudaStream_t stream = THCState_getCurrentStream(state);
61 |
62 | int S, N, C;
63 | get_sizes(x, &N, &C, &S);
64 |
65 | // Get pointers
66 | float *z_data = THCudaTensor_data(state, z);
67 | const float *x_data = THCudaTensor_data(state, x);
68 | const float *gamma_data = THCudaTensor_nDimension(state, gamma) != 0 ?
69 | THCudaTensor_data(state, gamma) : 0;
70 | const float *beta_data = THCudaTensor_nDimension(state, beta) != 0 ?
71 | THCudaTensor_data(state, beta) : 0;
72 | const float *mean_data = THCudaTensor_data(state, mean);
73 | const float *var_data = THCudaTensor_data(state, var);
74 |
75 | return _syncbn_forward_cuda(
76 | N, C, S, z_data, x_data, gamma_data, beta_data,
77 | mean_data, var_data, eps, stream);
78 |
79 | }
80 |
81 | extern "C" int syncbn_backward_xhat_cuda(
82 | const THCudaTensor *dz, const THCudaTensor *x,
83 | const THCudaTensor *mean, const THCudaTensor *var,
84 | THCudaTensor *sum_dz, THCudaTensor *sum_dz_xhat, float eps) {
85 | cudaStream_t stream = THCState_getCurrentStream(state);
86 |
87 | int S, N, C;
88 | get_sizes(dz, &N, &C, &S);
89 |
90 | // Get pointers
91 | const float *dz_data = THCudaTensor_data(state, dz);
92 | const float *x_data = THCudaTensor_data(state, x);
93 | const float *mean_data = THCudaTensor_data(state, mean);
94 | const float *var_data = THCudaTensor_data(state, var);
95 | float *sum_dz_data = THCudaTensor_data(state, sum_dz);
96 | float *sum_dz_xhat_data = THCudaTensor_data(state, sum_dz_xhat);
97 |
98 | return _syncbn_backward_xhat_cuda(
99 | N, C, S, dz_data, x_data, mean_data, var_data,
100 | sum_dz_data, sum_dz_xhat_data, eps, stream);
101 |
102 | }
103 | extern "C" int syncbn_backard_cuda(
104 | const THCudaTensor *dz, const THCudaTensor *x,
105 | const THCudaTensor *gamma, const THCudaTensor *beta,
106 | const THCudaTensor *mean, const THCudaTensor *var,
107 | const THCudaTensor *sum_dz, const THCudaTensor *sum_dz_xhat,
108 | THCudaTensor *dx, THCudaTensor *dgamma, THCudaTensor *dbeta, float eps) {
109 | cudaStream_t stream = THCState_getCurrentStream(state);
110 |
111 | int S, N, C;
112 | get_sizes(dz, &N, &C, &S);
113 |
114 | // Get pointers
115 | const float *dz_data = THCudaTensor_data(state, dz);
116 | const float *x_data = THCudaTensor_data(state, x);
117 | const float *gamma_data = THCudaTensor_nDimension(state, gamma) != 0 ?
118 | THCudaTensor_data(state, gamma) : 0;
119 | const float *beta_data = THCudaTensor_nDimension(state, beta) != 0 ?
120 | THCudaTensor_data(state, beta) : 0;
121 | const float *mean_data = THCudaTensor_data(state, mean);
122 | const float *var_data = THCudaTensor_data(state, var);
123 | const float *sum_dz_data = THCudaTensor_data(state, sum_dz);
124 | const float *sum_dz_xhat_data = THCudaTensor_data(state, sum_dz_xhat);
125 | float *dx_data = THCudaTensor_nDimension(state, dx) != 0 ?
126 | THCudaTensor_data(state, dx) : 0;
127 | float *dgamma_data = THCudaTensor_nDimension(state, dgamma) != 0 ?
128 | THCudaTensor_data(state, dgamma) : 0;
129 | float *dbeta_data = THCudaTensor_nDimension(state, dbeta) != 0 ?
130 | THCudaTensor_data(state, dbeta) : 0;
131 |
132 | return _syncbn_backward_cuda(
133 | N, C, S, dz_data, x_data, gamma_data, beta_data,
134 | mean_data, var_data, sum_dz_data, sum_dz_xhat_data,
135 | dx_data, dgamma_data, dbeta_data, eps, stream);
136 | }
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/src/syncbn.cu:
--------------------------------------------------------------------------------
#include <cuda_runtime_api.h>  // cudaStream_t, cudaError_t, cudaGetLastError

#include "common.h"
#include "syncbn.cu.h"

/*
 * Device functions and data structures
 */

// Pair of floats accumulated together so that one block reduction produces
// two per-channel totals at once (e.g. sum(x) and sum(x^2)).
struct Float2 {
  float v1, v2;
  __device__ Float2() {}
  __device__ Float2(float _v1, float _v2) : v1(_v1), v2(_v2) {}
  // Broadcast constructors: (Float2)0 yields a zero pair, as reduce() expects.
  __device__ Float2(float v) : v1(v), v2(v) {}
  __device__ Float2(int v) : v1(v), v2(v) {}
  __device__ Float2 &operator+=(const Float2 &a) {
    v1 += a.v1;
    v2 += a.v2;
    return *this;
  }
};
23 |
// Reduction functor that recovers y = (z - beta) / gamma from an output z
// and yields the pair (dz, y * dz) for one element.
// NOTE(review): not referenced anywhere in this file (XHatOp is used instead).
struct GradOp {
  __device__ GradOp(float _gamma, float _beta, const float *_z,
                    const float *_dz, int c, int s)
      : gamma(_gamma), beta(_beta), z(_z), dz(_dz), C(c), S(s) {}

  __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) {
    const int idx = (batch * C + plane) * S + n;
    const float y = (z[idx] - beta) / gamma;
    const float grad = dz[idx];
    return Float2(grad, y * grad);
  }

  const float gamma;
  const float beta;
  const float *z;
  const float *dz;
  const int C;
  const int S;
};
40 |
// Sum `val` over the WARP_SIZE lanes of the calling warp; every lane ends up
// with the warp total. NOTE(review): assumes all lanes of the warp are active
// (WARP_SHFL_XOR uses the full 0xffffffff mask by default).
static __device__ __forceinline__ float warpSum(float val) {
#if __CUDA_ARCH__ >= 300
  // Butterfly reduction: getMSB(WARP_SIZE) == 5 xor-shuffle steps.
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
    val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
  }
#else
  // Fallback for architectures below sm_30 (no shuffle): stage each lane's
  // value in shared memory and let every lane add the other 31 values of its
  // warp. NOTE(review): relies on warp-synchronous execution (no barrier).
  __shared__ float values[MAX_BLOCK_SIZE];
  values[threadIdx.x] = val;
  __threadfence_block();
  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
  for (int i = 1; i < WARP_SIZE; i++) {
    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
  }
#endif
  return val;
}
57 |
// Component-wise warp sum of a Float2 (v1 first, then v2).
static __device__ __forceinline__ Float2 warpSum(Float2 value) {
  const float total1 = warpSum(value.v1);
  const float total2 = warpSum(value.v2);
  return Float2(total1, total2);
}
63 |
// Block-wide reduction of op(batch, plane, n) over all N batches and all S
// positions of channel `plane`; every thread of the block receives the total.
// T must support (T)0, += and warpSum(T); Op is a functor returning T.
// Assumes blockDim.x is a multiple of WARP_SIZE and at most
// WARP_SIZE * WARP_SIZE (the launchers use getNumThreads: 32..512 threads).
template <typename T, typename Op>
__device__ T reduce(Op op, int plane, int N, int C, int S) {
  T sum = (T)0;
  for (int batch = 0; batch < N; ++batch) {
    for (int x = threadIdx.x; x < S; x += blockDim.x) {
      sum += op(batch, plane, x);
    }
  }

  // sum over NumThreads within a warp
  sum = warpSum(sum);

  // 'transpose', and reduce within warp again: lane 0 of each warp stores its
  // partial sum in shared[warp index], then warp 0 reduces those partials.
  __shared__ T shared[32];
  __syncthreads();
  if (threadIdx.x % WARP_SIZE == 0) {
    shared[threadIdx.x / WARP_SIZE] = sum;
  }
  if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
    // zero out the other entries in shared
    shared[threadIdx.x] = (T)0;
  }
  __syncthreads();
  if (threadIdx.x / WARP_SIZE == 0) {
    sum = warpSum(shared[threadIdx.x]);
    if (threadIdx.x == 0) {
      shared[0] = sum;
    }
  }
  __syncthreads();

  // Everyone picks up the block total from shared[0].
  return shared[0];
}
98 |
99 | /*----------------------------------------------------------------------------
100 | *
101 | * BatchNorm2dSyncFunc Kernel implementations
102 | *
103 | *---------------------------------------------------------------------------*/
104 |
// Reduction functor yielding (t, t * t) for one element of channel `plane`,
// so a single reduce pass produces sum(x) and sum(x^2).
struct SqSumOp {
  __device__ SqSumOp(const float *t, int c, int s) : tensor(t), C(c), S(s) {}

  __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) {
    const int offset = (batch * C + plane) * S + n;
    const float v = tensor[offset];
    return Float2(v, v * v);
  }

  const float *tensor;
  const int C;
  const int S;
};
116 |
// Reduction functor for the backward pass: for one element of channel `plane`
// yields (dz, dz * xhat) with xhat = (x - mean) * invstd, so a single reduce
// pass produces sum(dz) and sum(dz * xhat).
// Members are named after what the caller actually passes
// (syncbn_backward_xhat_kernel passes invstd, mean and the input x).
struct XHatOp {
  __device__ XHatOp(float _invstd, float _mean, const float *_x,
                    const float *_dz, int c, int s)
      : invstd(_invstd), mean(_mean), x(_x), dz(_dz), C(c), S(s) {}
  __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) {
    // xhat = (x - mean) * invstd
    float _xhat = (x[(batch * C + plane) * S + n] - mean) * invstd;
    // upstream gradient for this element
    float _dz = dz[(batch * C + plane) * S + n];
    return Float2(_dz, _dz * _xhat);
  }
  const float invstd;
  const float mean;
  const float *x;
  const float *dz;
  const int C;
  const int S;
};
135 |
// One block per channel (plane = blockIdx.x): writes this device's sum(x) and
// sum(x^2) for that channel, over all N batches and S positions.
__global__ void syncbn_sum_sqsum_kernel(const float *x, float *sum, float *sqsum,
                                        int N, int C, int S) {
  int plane = blockIdx.x;
  // T = Float2 cannot be deduced from the arguments, so it is given explicitly.
  Float2 res = reduce<Float2>(SqSumOp(x, C, S), plane, N, C, S);
  float _sum = res.v1;
  float _sqsum = res.v2;
  __syncthreads();
  if (threadIdx.x == 0) {
    sum[plane] = _sum;
    sqsum[plane] = _sqsum;
  }
}
148 |
// One block per channel (c = blockIdx.x):
//   z = (x - mean) * invstd * gamma + beta,  invstd = 1 / sqrt(var + eps).
// A null gamma/beta means 1 / 0. When var and eps are both zero, invstd is
// forced to 0 instead of dividing by zero.
__global__ void syncbn_forward_kernel(
    float *z, const float *x, const float *gamma, const float *beta,
    const float *mean, const float *var, float eps, int N, int C, int S) {
  const int c = blockIdx.x;
  const float ch_mean = mean[c];
  const float ch_var = var[c];
  const float scale = (gamma == 0) ? 1.f : gamma[c];
  const float shift = (beta == 0) ? 0.f : beta[c];

  float invstd = 0;
  if (!(ch_var == 0.f && eps == 0.f)) {
    invstd = 1 / sqrt(ch_var + eps);
  }

  for (int batch = 0; batch < N; ++batch) {
    const int base = (batch * C + c) * S;
    for (int n = threadIdx.x; n < S; n += blockDim.x) {
      const float normalized = (x[base + n] - ch_mean) * invstd;
      z[base + n] = normalized * scale + shift;
    }
  }
}
171 |
// One block per channel (c = blockIdx.x): writes this device's sum(dz) and
// sum(dz * xhat) for that channel, with xhat = (x - mean) / sqrt(var + eps).
__global__ void syncbn_backward_xhat_kernel(
    const float *dz, const float *x, const float *mean, const float *var,
    float *sum_dz, float *sum_dz_xhat, float eps, int N, int C, int S) {

  int c = blockIdx.x;
  float _mean = mean[c];
  float _var = var[c];
  // invstd stays 0 when var and eps are both zero (avoids division by zero).
  float _invstd = 0;
  if (_var != 0.f || eps != 0.f) {
    _invstd = 1 / sqrt(_var + eps);
  }
  // T = Float2 cannot be deduced from the arguments, so it is given explicitly.
  Float2 res = reduce<Float2>(
      XHatOp(_invstd, _mean, x, dz, C, S), c, N, C, S);
  // \sum(\frac{dJ}{dy_i})
  float _sum_dz = res.v1;
  // \sum(\frac{dJ}{dy_i}*\hat{x_i})
  float _sum_dz_xhat = res.v2;
  __syncthreads();
  if (threadIdx.x == 0) {
    // \sum(\frac{dJ}{dy_i})
    sum_dz[c] = _sum_dz;
    // \sum(\frac{dJ}{dy_i}*\hat{x_i})
    sum_dz_xhat[c] = _sum_dz_xhat;
  }
}
197 |
198 |
// Backward pass for one channel (c = blockIdx.x). Writes dx when dx is
// non-null and accumulates (+=) into dgamma[c] / dbeta[c] when non-null.
// sum_dz[c] and sum_dz_xhat[c] must already be the global sums of dz and
// dz * xhat divided by the global element count (see the note below).
// NOTE(review): `beta` is accepted but never read here.
__global__ void syncbn_backward_kernel(
    const float *dz, const float *x, const float *gamma, const float *beta,
    const float *mean, const float *var,
    const float *sum_dz, const float *sum_dz_xhat,
    float *dx, float *dgamma, float *dbeta,
    float eps, int N, int C, int S) {

  int c = blockIdx.x;
  float _mean = mean[c];
  float _var = var[c];
  // A null gamma means no affine weight, i.e. gamma = 1.
  float _gamma = gamma != 0 ? gamma[c] : 1.f;
  float _sum_dz = sum_dz[c];
  float _sum_dz_xhat = sum_dz_xhat[c];
  // invstd stays 0 when var and eps are both zero (avoids division by zero).
  float _invstd = 0;
  if (_var != 0.f || eps != 0.f) {
    _invstd = 1 / sqrt(_var + eps);
  }
  /*
  \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
    N\frac{dJ}{d\hat{x_i}} -
    \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
    \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
  )
  Note : N is omitted here since it will be accumulated and
  _sum_dz and _sum_dz_xhat expected to be already normalized
  before the call.
  */
  if (dx != 0) {
    // dJ/dxhat = gamma * dJ/dy, hence the gamma factor.
    float _mul = _gamma * _invstd;
    for (int batch = 0; batch < N; ++batch) {
      for (int n = threadIdx.x; n < S; n += blockDim.x) {
        float _dz = dz[(batch * C + c) * S + n];
        float _xhat = (x[(batch * C + c) * S + n] - _mean) * _invstd;
        float _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
        dx[(batch * C + c) * S + n] = _dx;
      }
    }
  }
  // Element count of channel c on THIS device (N, S come from the local
  // tensor). The sums were normalized by the global count, so multiplying by
  // the local count gives this device's share; presumably the per-device
  // dgamma/dbeta are then summed across replicas — confirm against the
  // Python backward.
  float _norm = N * S;
  if (dgamma != 0) {
    if (threadIdx.x == 0) {
      // \frac{dJ}{d\gamma} = \sum(\frac{dJ}{dy_i}*\hat{x_i})
      dgamma[c] += _sum_dz_xhat * _norm;
    }
  }
  if (dbeta != 0) {
    if (threadIdx.x == 0) {
      // \frac{dJ}{d\beta} = \sum(\frac{dJ}{dy_i})
      dbeta[c] += _sum_dz * _norm;
    }
  }
}
251 |
// Launch syncbn_sum_sqsum_kernel on `stream`: one block per channel, block
// size chosen by getNumThreads(S).
// Returns 1 if the launch succeeded, 0 otherwise. Only launch errors are
// detected here; the kernel itself runs asynchronously.
extern "C" int _syncbn_sum_sqsum_cuda(int N, int C, int S,
                                      const float *x, float *sum, float *sqsum,
                                      cudaStream_t stream) {
  // Run kernel
  dim3 blocks(C);
  dim3 threads(getNumThreads(S));
  syncbn_sum_sqsum_kernel<<<blocks, threads, 0, stream>>>(x, sum, sqsum, N, C, S);

  // Check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
    return 0;
  else
    return 1;
}
267 |
// Launch syncbn_forward_kernel on `stream`: one block per channel, block
// size chosen by getNumThreads(S).
// Returns 1 if the launch succeeded, 0 otherwise (launch errors only).
extern "C" int _syncbn_forward_cuda(
    int N, int C, int S, float *z, const float *x,
    const float *gamma, const float *beta, const float *mean, const float *var,
    float eps, cudaStream_t stream) {

  // Run kernel
  dim3 blocks(C);
  dim3 threads(getNumThreads(S));
  syncbn_forward_kernel<<<blocks, threads, 0, stream>>>(
      z, x, gamma, beta, mean, var, eps, N, C, S);

  // Check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
    return 0;
  else
    return 1;
}
286 |
287 |
// Launch syncbn_backward_xhat_kernel on `stream`: one block per channel,
// block size chosen by getNumThreads(S).
// Returns 1 if the launch succeeded, 0 otherwise (launch errors only).
extern "C" int _syncbn_backward_xhat_cuda(
    int N, int C, int S, const float *dz, const float *x,
    const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat,
    float eps, cudaStream_t stream) {

  // Run kernel
  dim3 blocks(C);
  dim3 threads(getNumThreads(S));
  syncbn_backward_xhat_kernel<<<blocks, threads, 0, stream>>>(
      dz, x, mean, var, sum_dz, sum_dz_xhat, eps, N, C, S);

  // Check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
    return 0;
  else
    return 1;
}
306 |
307 |
// Launch syncbn_backward_kernel on `stream`: one block per channel, block
// size chosen by getNumThreads(S).
// Returns 1 if the launch succeeded, 0 otherwise (launch errors only).
extern "C" int _syncbn_backward_cuda(
    int N, int C, int S, const float *dz, const float *x,
    const float *gamma, const float *beta, const float *mean, const float *var,
    const float *sum_dz, const float *sum_dz_xhat,
    float *dx, float *dgamma, float *dbeta, float eps, cudaStream_t stream) {

  // Run kernel
  dim3 blocks(C);
  dim3 threads(getNumThreads(S));
  syncbn_backward_kernel<<<blocks, threads, 0, stream>>>(
      dz, x, gamma, beta, mean, var, sum_dz, sum_dz_xhat,
      dx, dgamma, dbeta, eps, N, C, S);

  // Check for errors
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
    return 0;
  else
    return 1;
}
328 |
329 |
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/src/syncbn.cu.h:
--------------------------------------------------------------------------------
#ifndef SYNCBN_CU_H
#define SYNCBN_CU_H

// Makes the header self-contained: declares cudaStream_t.
#include <cuda_runtime_api.h>

/*
 * Exported functions (implemented in syncbn.cu).
 * Each returns 1 if its kernel launch succeeded and 0 otherwise.
 */
extern "C" int _syncbn_sum_sqsum_cuda(int N, int C, int S, const float *x,
                                      float *sum, float *sqsum,
                                      cudaStream_t stream);
extern "C" int _syncbn_forward_cuda(
    int N, int C, int S, float *z, const float *x,
    const float *gamma, const float *beta, const float *mean, const float *var,
    float eps, cudaStream_t stream);
extern "C" int _syncbn_backward_xhat_cuda(
    int N, int C, int S, const float *dz, const float *x,
    const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat,
    float eps, cudaStream_t stream);
extern "C" int _syncbn_backward_cuda(
    int N, int C, int S, const float *dz, const float *x,
    const float *gamma, const float *beta, const float *mean, const float *var,
    const float *sum_dz, const float *sum_dz_xhat,
    float *dx, float *dgamma, float *dbeta,
    float eps, cudaStream_t stream);


#endif  // SYNCBN_CU_H
27 |
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/src/syncbn.cu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/src/syncbn.cu.o
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/_syncbn/src/syncbn.h:
--------------------------------------------------------------------------------
/*
 * Declarations of the functions exposed through the `_ext.syncbn` ffi module
 * (listed in build.py `headers`, implemented in syncbn.cpp). Each returns 1
 * on success and 0 if the underlying CUDA kernel launch reported an error.
 */

/* Per-channel sum(x) and sum(x^2) of x, written to sum / sqsum. */
int syncbn_sum_sqsum_cuda(
    const THCudaTensor *x, THCudaTensor *sum, THCudaTensor *sqsum);
/* Forward: z = gamma * (x - mean) / sqrt(var + eps) + beta.
   Empty gamma / beta tensors mean "no affine parameter". */
int syncbn_forward_cuda(
    THCudaTensor *z, const THCudaTensor *x,
    const THCudaTensor *gamma, const THCudaTensor *beta,
    const THCudaTensor *mean, const THCudaTensor *var, float eps);
/* Per-channel sum(dz) and sum(dz * xhat). */
int syncbn_backward_xhat_cuda(
    const THCudaTensor *dz, const THCudaTensor *x,
    const THCudaTensor *mean, const THCudaTensor *var,
    THCudaTensor *sum_dz, THCudaTensor *sum_dz_xhat,
    float eps);
/* Gradients dx, dgamma, dbeta from the normalized global sums.
   NOTE: "backard" is misspelled but is the exported symbol name; callers of
   the ffi module presumably use this exact name, so it is left unchanged. */
int syncbn_backard_cuda(
    const THCudaTensor *dz, const THCudaTensor *x,
    const THCudaTensor *gamma, const THCudaTensor *beta,
    const THCudaTensor *mean, const THCudaTensor *var,
    const THCudaTensor *sum_dz, const THCudaTensor *sum_dz_xhat,
    THCudaTensor *dx, THCudaTensor *dgamma, THCudaTensor *dbeta, float eps);
18 |
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/syncbn.py:
--------------------------------------------------------------------------------
1 | """
2 | /*****************************************************************************/
3 |
4 | BatchNorm2dSync with multi-gpu
5 |
6 | code referenced from : https://github.com/mapillary/inplace_abn
7 |
8 | /*****************************************************************************/
9 | """
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import torch.cuda.comm as comm
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | from ._syncbn._ext import syncbn as _lib_bn
19 |
20 |
def _count_samples(x):
    """Return the number of elements per channel of ``x``.

    This is the product of every entry of ``x.size()`` except index 1 (the
    channel dimension of an ``(N, C, ...)`` tensor). A 0-d input gives 1.
    """
    sizes = list(x.size())
    count = 1
    for dim_size in sizes[:1] + sizes[2:]:
        count *= dim_size
    return count
27 |
28 |
def _check_contiguous(*args):
    """Raise ``ValueError`` if any non-``None`` argument is not contiguous.

    ``None`` entries are skipped, so optional tensors may be passed directly.
    """
    for tensor in args:
        if tensor is not None and not tensor.is_contiguous():
            raise ValueError("Non-contiguous input")
32 |
33 |
class BatchNorm2dSyncFunc(Function):
    """Batch-norm autograd function whose statistics are shared across GPUs.

    ``extra`` (parsed by ``_parse_extra``) tells each replica whether it is
    the master or a worker and hands it the queues used for the exchange:

    * forward: workers put ``(sum(x), sum(x^2), local_count)`` on
      ``master_queue``; the master reduces them, derives the global mean and
      variances and puts ``(mean, uvar, var)`` on each worker's queue.
    * backward: the same pattern is repeated for ``sum(dz)`` and
      ``sum(dz * x_hat)``.

    Every replica therefore has to take part in each forward/backward call:
    the master blocks on ``master_queue.get()`` until all
    ``master_queue.maxsize`` workers have reported.
    """

    @classmethod
    def forward(cls, ctx, x, weight, bias, running_mean, running_var,
                extra, compute_stats=True, momentum=0.1, eps=1e-05):
        """Normalise ``x`` with globally synchronised batch statistics.

        When ``compute_stats`` is False the running statistics are used and
        nothing is exchanged (and nothing is saved for backward).
        """
        # Save context
        if extra is not None:
            cls._parse_extra(ctx, extra)
        ctx.compute_stats = compute_stats
        ctx.momentum = momentum
        ctx.eps = eps
        if ctx.compute_stats:
            # Number of values per channel on *this* replica. Replicas may
            # hold different batch sizes (DataParallel splits a batch of 5 on
            # 2 GPUs as 3 + 2), so the global count is summed from the real
            # per-replica counts instead of being assumed local * replicas.
            local_count = _count_samples(x)
            # Same sanity check as before, evaluated on local information
            # only, so every replica fails before touching the queues.
            assert local_count * (ctx.master_queue.maxsize + 1) > 1
            num_features = running_mean.size(0)
            # 1. compute sum(x) and sum(x^2) on this replica
            xsum = x.new().resize_(num_features)
            xsqsum = x.new().resize_(num_features)
            _check_contiguous(x, xsum, xsqsum)
            _lib_bn.syncbn_sum_sqsum_cuda(x.detach(), xsum, xsqsum)
            if ctx.is_master:
                xsums, xsqsums, counts = [xsum], [xsqsum], [local_count]
                # master : gather sum(x), sum(x^2) and counts from workers
                for _ in range(ctx.master_queue.maxsize):
                    xsum_w, xsqsum_w, count_w = ctx.master_queue.get()
                    ctx.master_queue.task_done()
                    xsums.append(xsum_w)
                    xsqsums.append(xsqsum_w)
                    counts.append(count_w)
                N = sum(counts)
                xsum = comm.reduce_add(xsums)
                xsqsum = comm.reduce_add(xsqsums)
                mean = xsum / N
                sumvar = xsqsum - xsum * mean
                var = sumvar / N  # biased: used to normalise x below
                uvar = sumvar / (N - 1)  # unbiased: used for running_var
                # master : broadcast global mean, variance to all workers
                tensors = comm.broadcast_coalesced(
                    (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
                for ts, queue in zip(tensors[1:], ctx.worker_queues):
                    queue.put(ts)
            else:
                # worker : send sum(x), sum(x^2) and local count to master
                ctx.master_queue.put((xsum, xsqsum, local_count))
                # worker : get global mean and variance
                mean, uvar, var = ctx.worker_queue.get()
                ctx.worker_queue.task_done()
                # Workers never learn the global count; only the master's
                # backward reads ctx.N.
                N = None

            # Update running stats
            running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
            running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
            ctx.N = N
            ctx.save_for_backward(x, weight, bias, mean, var)
        else:
            mean, var = running_mean, running_var

        output = x.new().resize_as_(x)
        _check_contiguous(output, x, mean, var, weight, bias)
        # do batch norm forward
        _lib_bn.syncbn_forward_cuda(
            output, x, weight if weight is not None else x.new(),
            bias if bias is not None else x.new(), mean, var, ctx.eps)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, dz):
        """Compute gradients using gradient statistics reduced across GPUs.

        Returns one gradient per ``forward`` input: ``dx``, ``dweight`` and
        ``dbias`` (``None`` when not required) followed by ``None`` for
        running_mean, running_var, extra, compute_stats, momentum and eps.
        """
        x, weight, bias, mean, var = ctx.saved_tensors
        dz = dz.contiguous()
        if ctx.needs_input_grad[0]:
            dx = dz.new().resize_as_(dz)
        else:
            dx = None
        if ctx.needs_input_grad[1]:
            dweight = dz.new().resize_as_(mean).zero_()
        else:
            dweight = None
        if ctx.needs_input_grad[2]:
            dbias = dz.new().resize_as_(mean).zero_()
        else:
            dbias = None
        _check_contiguous(x, dz, weight, bias, mean, var)

        # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
        num_features = mean.size(0)
        sum_dz = x.new().resize_(num_features)
        sum_dz_xhat = x.new().resize_(num_features)
        _check_contiguous(sum_dz, sum_dz_xhat)
        _lib_bn.syncbn_backward_xhat_cuda(
            dz, x, mean, var, sum_dz, sum_dz_xhat, ctx.eps)
        if ctx.is_master:
            sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
            # master : gather from workers
            for _ in range(ctx.master_queue.maxsize):
                sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
                ctx.master_queue.task_done()
                sum_dzs.append(sum_dz_w)
                sum_dz_xhats.append(sum_dz_xhat_w)
            # master : compute global stats (averaged over the global count)
            sum_dz = comm.reduce_add(sum_dzs)
            sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
            sum_dz /= ctx.N
            sum_dz_xhat /= ctx.N
            # master : broadcast global stats
            tensors = comm.broadcast_coalesced(
                (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
            for ts, queue in zip(tensors[1:], ctx.worker_queues):
                queue.put(ts)
        else:
            # worker : send to master
            ctx.master_queue.put((sum_dz, sum_dz_xhat))
            # worker : get global stats
            sum_dz, sum_dz_xhat = ctx.worker_queue.get()
            ctx.worker_queue.task_done()

        # do batch norm backward (the extension function is spelled this way)
        _lib_bn.syncbn_backard_cuda(
            dz, x, weight if weight is not None else dz.new(),
            bias if bias is not None else dz.new(),
            mean, var, sum_dz, sum_dz_xhat,
            dx if dx is not None else dz.new(),
            dweight if dweight is not None else dz.new(),
            dbias if dbias is not None else dz.new(), ctx.eps)

        # Nine forward inputs -> nine gradients.
        return dx, dweight, dbias, None, None, None, None, None, None

    @staticmethod
    def _parse_extra(ctx, extra):
        """Copy the master/worker role and queue handles from ``extra``."""
        ctx.is_master = extra["is_master"]
        if ctx.is_master:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queues = extra["worker_queues"]
            ctx.worker_ids = extra["worker_ids"]
        else:
            ctx.master_queue = extra["master_queue"]
            ctx.worker_queue = extra["worker_queue"]


batchnorm2d_sync = BatchNorm2dSyncFunc.apply

__all__ = ["batchnorm2d_sync"]
173 |
--------------------------------------------------------------------------------
/models/syncbn/modules/functional/syncbn.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/syncbn.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .syncbn import *
2 |
--------------------------------------------------------------------------------
/models/syncbn/modules/nn/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/__init__.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/nn/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/nn/__pycache__/syncbn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/__pycache__/syncbn.cpython-37.pyc
--------------------------------------------------------------------------------
/models/syncbn/modules/nn/syncbn.py:
--------------------------------------------------------------------------------
1 | """
2 | /*****************************************************************************/
3 |
4 | BatchNorm2dSync with multi-gpu
5 |
6 | /*****************************************************************************/
7 | """
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | try:
13 | # python 3
14 | from queue import Queue
15 | except ImportError:
16 | # python 2
17 | from Queue import Queue
18 |
19 | import torch
20 | import torch.nn as nn
21 | from modules.functional import batchnorm2d_sync
22 |
23 |
class BatchNorm2d(nn.BatchNorm2d):
    """
    BatchNorm2d with automatic multi-GPU Sync

    When more than one CUDA device is visible, the layer is computing batch
    statistics and the input lives on a GPU, the statistics of all replicas
    are combined through ``batchnorm2d_sync``: the replica on
    ``self.devices[0]`` acts as master, replicas on the other devices act as
    workers. In every other case it behaves like ``torch.nn.BatchNorm2d``.

    NOTE(review): the master waits for one message from every other visible
    GPU, so this presumably has to run replicated (e.g. DataParallel) over
    all ``torch.cuda.device_count()`` devices; running a single replica on
    device 0 while several GPUs are visible would block. Confirm with callers.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(BatchNorm2d, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine,
            track_running_stats=track_running_stats)
        self.devices = list(range(torch.cuda.device_count()))
        if len(self.devices) > 1:
            # Initialize queues: one shared queue the workers fill for the
            # master, and one single-slot reply queue per worker.
            self.worker_ids = self.devices[1:]
            self.master_queue = Queue(len(self.worker_ids))
            self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def forward(self, x):
        compute_stats = self.training or not self.track_running_stats
        # Only CUDA tensors may take the synchronised path: it looks up the
        # worker queue by device id (a CPU tensor reports -1, which is not a
        # worker id) and runs CUDA kernels. Everything else falls back to
        # the stock implementation.
        if compute_stats and len(self.devices) > 1 and x.is_cuda:
            device = x.get_device()
            if device == self.devices[0]:
                # Master mode
                extra = {
                    "is_master": True,
                    "master_queue": self.master_queue,
                    "worker_queues": self.worker_queues,
                    "worker_ids": self.worker_ids
                }
            else:
                # Worker mode
                extra = {
                    "is_master": False,
                    "master_queue": self.master_queue,
                    "worker_queue": self.worker_queues[
                        self.worker_ids.index(device)]
                }
            return batchnorm2d_sync(x, self.weight, self.bias,
                                    self.running_mean, self.running_var,
                                    extra, compute_stats, self.momentum,
                                    self.eps)
        return super(BatchNorm2d, self).forward(x)

    def __repr__(self):
        """repr"""
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, devices={devices})'
        return rep.format(name=self.__class__.__name__, **self.__dict__)
71 |
--------------------------------------------------------------------------------
/models/syncbn/modules/nn/syncbn.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/syncbn.pyc
--------------------------------------------------------------------------------
/models/syncbn/requirements.txt:
--------------------------------------------------------------------------------
1 | future
2 | cffi
3 |
--------------------------------------------------------------------------------
/models/syncbn/test.py:
--------------------------------------------------------------------------------
1 | """
2 | /*****************************************************************************/
3 |
4 | Test for BatchNorm2dSync with multi-gpu
5 |
6 | /*****************************************************************************/
7 | """
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import sys
13 | import numpy as np
14 | import torch
15 | from torch import nn
16 | from torch.nn import functional as F
17 | sys.path.append("./")
18 | from modules import nn as NN
19 |
20 | torch.backends.cudnn.deterministic = True
21 |
22 |
def init_weight(model):
    """Initialise the layers of ``model`` in place.

    * ``nn.Conv2d`` weights ~ N(0, sqrt(2 / (kh * kw * out_channels))).
    * ``nn.BatchNorm2d`` (including the sync ``NN.BatchNorm2d``, which
      subclasses it): weight = 1, bias = 0.
    * ``nn.Linear``: bias = 0.

    Missing parameters (``bias=False`` layers, ``affine=False`` batch norms)
    are skipped instead of raising ``AttributeError`` on ``None``.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, np.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            if m.bias is not None:
                m.bias.data.zero_()
33 |
# Compare plain nn.BatchNorm2d (m1, single GPU) against the synchronised
# NN.BatchNorm2d (m2, DataParallel over all GPUs) on identical data.
num_gpu = torch.cuda.device_count()
print("num_gpu={}".format(num_gpu))
if num_gpu == 0:
    # Everything below moves modules/tensors with .cuda(); stop with a clear
    # message instead of an opaque CUDA initialisation error.
    sys.exit("No CUDA device found; this test needs at least one GPU.")
if num_gpu < 2:
    print("No multi-gpu found. NN.BatchNorm2d will act as normal nn.BatchNorm2d")

m1 = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    nn.BatchNorm2d(3),
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    nn.BatchNorm2d(3),
).cuda()
torch.manual_seed(123)
init_weight(m1)
m2 = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
).cuda()
torch.manual_seed(123)
init_weight(m2)
m2 = nn.DataParallel(m2, device_ids=range(num_gpu))
o1 = torch.optim.SGD(m1.parameters(), 1e-3)
o2 = torch.optim.SGD(m2.parameters(), 1e-3)
y = torch.ones(num_gpu).float().cuda()
torch.manual_seed(123)
for _ in range(100):
    x = torch.rand(num_gpu, 3, 2, 2).cuda()
    o1.zero_grad()
    z1 = m1(x)
    l1 = F.mse_loss(z1.mean(-1).mean(-1).mean(-1), y)
    l1.backward()
    o1.step()
    o2.zero_grad()
    z2 = m2(x)
    l2 = F.mse_loss(z2.mean(-1).mean(-1).mean(-1), y)
    l2.backward()
    o2.step()
    # Gradient differences between the two models (should be ~0).
    print(m2.module[1].bias.grad - m1[1].bias.grad)
    print(m2.module[1].weight.grad - m1[1].weight.grad)
    print(m2.module[-1].bias.grad - m1[-1].bias.grad)
    print(m2.module[-1].weight.grad - m1[-1].weight.grad)
m2 = m2.module
print("===============================")
print("m1(nn.BatchNorm2d) running_mean",
      m1[1].running_mean, m1[-1].running_mean)
print("m2(NN.BatchNorm2d) running_mean",
      m2[1].running_mean, m2[-1].running_mean)
print("m1(nn.BatchNorm2d) running_var", m1[1].running_var, m1[-1].running_var)
print("m2(NN.BatchNorm2d) running_var", m2[1].running_var, m2[-1].running_var)
print("m1(nn.BatchNorm2d) weight", m1[1].weight, m1[-1].weight)
print("m2(NN.BatchNorm2d) weight", m2[1].weight, m2[-1].weight)
print("m1(nn.BatchNorm2d) bias", m1[1].bias, m1[-1].bias)
print("m2(NN.BatchNorm2d) bias", m2[1].bias, m2[-1].bias)
90 |
--------------------------------------------------------------------------------
/ranking_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 |
7 | """
8 | Sampling strategies: RS (Random Sampling), EGS (Edge-Guided Sampling), and IGS (Instance-Guided Sampling)
9 | """
def randomSampling(inputs, targets, masks, threshold, sample_num):
    """Random Sampling (RS): draw up to ``sample_num`` random point pairs.

    Only pixels whose target is greater than ``threshold`` are candidates.
    The candidates are shuffled once and consecutive shuffled entries are
    paired (0-1, 2-3, ...), so A and B of a pair are different pixels.

    Args:
        inputs: flattened predictions of one image (``inputs[i, :]``).
        targets: flattened ground truth of the same image.
        masks: flattened validity mask of the same image.
        threshold: targets must exceed this value to be sampled (the loss
            passes ``self.mask_value``).
        sample_num: maximum number of pairs to draw.

    Returns:
        ``inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A,
        consistent_masks_B`` -- six 1-D tensors of equal length
        ``min(sample_num, num_candidates // 2)``.
    """
    # find candidate pixels (valid GT) and gather their values
    valid = targets.gt(threshold)
    inputs_index = torch.masked_select(inputs, valid)
    target_index = torch.masked_select(targets, valid)
    consistent_masks_index = torch.masked_select(masks, valid)

    # Shuffle with the default generator (as before) and move the
    # permutation to the data's device instead of hard-coding CUDA.
    num_effect_pixels = len(inputs_index)
    shuffle_effect_pixels = torch.randperm(num_effect_pixels).to(inputs.device)

    # A and B must have the same length: with an odd number of candidates
    # the last shuffled one is left out.
    num_pairs = min(sample_num, num_effect_pixels // 2)
    index_A = shuffle_effect_pixels[0:num_pairs * 2:2]
    index_B = shuffle_effect_pixels[1:num_pairs * 2:2]

    # A-B point pairs from predictions
    inputs_A = inputs_index[index_A]
    inputs_B = inputs_index[index_B]
    # corresponding pairs from GT
    targets_A = target_index[index_A]
    targets_B = target_index[index_B]
    # only compute the losses of point pairs with valid GT
    consistent_masks_A = consistent_masks_index[index_A]
    consistent_masks_B = consistent_masks_index[index_B]

    return inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B
41 |
42 | ###########
43 | # EDGE-GUIDED SAMPLING
44 | # input:
45 | # inputs[i,:], targets[i, :], masks[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w
46 | # return:
47 | # inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B
48 | ###########
def ind2sub(idx, cols):
    """Convert linear indices into ``(row, col)`` subscripts (row-major).

    Args:
        idx: non-negative linear index (int or integer tensor).
        cols: number of columns (image width).

    Returns:
        ``(r, c)`` with ``idx == r * cols + c`` and ``0 <= c < cols``.
    """
    # Floor division is required: ``/`` is true division for Python 3 ints
    # and for integer tensors in current PyTorch, which made ``r``
    # fractional and ``c`` ~0.
    r = idx // cols
    c = idx - r * cols
    return r, c
53 |
def sub2ind(r, c, cols):
    """Convert ``(row, col)`` subscripts into row-major linear indices."""
    row_offset = r * cols
    return row_offset + c
57 |
def edgeGuidedSampling(inputs, targets, edges_img, thetas_img, masks, h, w):
    """Edge-Guided Sampling (EGS) of point pairs around image edges.

    Pixels whose edge magnitude is at least 10% of the image maximum are
    edge points. ``sample_num`` anchors are drawn from them with replacement,
    where ``sample_num`` equals the number of edge points. For each anchor,
    four points are placed along its angle ``theta`` (the Sobel gradient
    angle produced by ``EdgeguidedRankingLoss.getEdge``) at random distances
    in [2, 30] pixels: rows 0-1 on the negative side, rows 2-3 on the
    positive side. The pairs (a, b), (b, c) and (c, d) are returned.

    Args:
        inputs, targets, masks: flattened (h * w) predictions, ground truth
            and validity mask of one image.
        edges_img, thetas_img: flattened edge magnitude / angle maps.
        h, w: image height and width.

    Returns:
        ``inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B,
        sample_num``; each tensor holds ``3 * sample_num`` entries.
    """
    device = inputs.device

    # find edges
    edges_max = edges_img.max()
    edges_mask = edges_img.ge(edges_max * 0.1)
    edges_loc = edges_mask.nonzero()
    thetas_edge = torch.masked_select(thetas_img, edges_mask)
    minlen = edges_loc.size(0)

    # find anchor points (i.e. edge points)
    sample_num = minlen
    index_anchors = torch.randint(0, minlen, (sample_num,), dtype=torch.long).to(device)
    theta_anchors = torch.gather(thetas_edge, 0, index_anchors)
    row_anchors, col_anchors = ind2sub(edges_loc[index_anchors].squeeze(1), w)

    # compute the coordinates of 4 points per anchor; distances are from [2, 30]
    distance_matrix = torch.randint(2, 31, (4, sample_num)).to(device)
    pos_or_neg = torch.ones(4, sample_num, device=device)
    pos_or_neg[:2, :] = -pos_or_neg[:2, :]
    distance_matrix = distance_matrix.float() * pos_or_neg
    cos_theta = torch.cos(theta_anchors).unsqueeze(0)
    sin_theta = torch.sin(theta_anchors).unsqueeze(0)
    col = col_anchors.unsqueeze(0).expand(4, sample_num).long() \
        + torch.round(distance_matrix.double() * cos_theta).long()
    row = row_anchors.unsqueeze(0).expand(4, sample_num).long() \
        + torch.round(distance_matrix.double() * sin_theta).long()

    # constrain 0 <= col <= w-1 and 0 <= row <= h-1 so every sampled point
    # lies inside the image and the gathers below stay in range
    col = col.clamp(0, w - 1)
    row = row.clamp(0, h - 1)

    # a-b, b-c, c-d
    a = sub2ind(row[0, :], col[0, :], w)
    b = sub2ind(row[1, :], col[1, :], w)
    c = sub2ind(row[2, :], col[2, :], w)
    d = sub2ind(row[3, :], col[3, :], w)
    A = torch.cat((a, b, c), 0)
    B = torch.cat((b, c, d), 0)

    inputs_A = torch.gather(inputs, 0, A.long())
    inputs_B = torch.gather(inputs, 0, B.long())
    targets_A = torch.gather(targets, 0, A.long())
    targets_B = torch.gather(targets, 0, B.long())
    masks_A = torch.gather(masks, 0, A.long())
    masks_B = torch.gather(masks, 0, B.long())

    return inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num
107 |
108 | ######################################################
109 | # EdgeguidedRankingLoss (optional multi-scale gradient regularization term, currently disabled)
110 | # Uncomment the regularization_loss lines (requires GradientLoss) to add the multi-scale gradient matching term
111 | #####################################################
class EdgeguidedRankingLoss(nn.Module):
    """Structure-guided ranking loss over EGS + RS point pairs.

    Args:
        point_pairs: stored as ``self.point_pairs``. NOTE(review): not read
            by ``forward``; the number of pairs follows the number of edge
            anchors. Kept for API compatibility.
        sigma: ratio tolerance; pairs with target ratio inside
            ``(1/(1+sigma), 1+sigma)`` count as "equal".
        alpha: weight of the "equal" term relative to the ordinal term.
        mask_value: targets must exceed this value to count as valid GT.
    """

    def __init__(self, point_pairs=10000, sigma=0.03, alpha=1.0, mask_value=-1e-8):
        super(EdgeguidedRankingLoss, self).__init__()
        self.point_pairs = point_pairs  # number of point pairs
        self.sigma = sigma  # used for determining the ordinal relationship between a selected pair
        self.alpha = alpha  # used for balancing the effect of = and (<,>)
        self.mask_value = mask_value
        #self.regularization_loss = GradientLoss(scales=4)

    def getEdge(self, images):
        """Sobel edge magnitude and angle of ``images``.

        For 3-channel input only channel 0 is used; otherwise ``images`` is
        convolved directly, which requires a single channel.

        Returns:
            ``(edges, thetas)``, each ``(n, 1, h, w)`` with a zero border.
        """
        n, c, h, w = images.size()
        # Kernels follow the images' device/dtype (previously forced to
        # float32 CUDA tensors).
        kernel_opts = dict(dtype=images.dtype, device=images.device)
        a = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], **kernel_opts).view(1, 1, 3, 3)
        b = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], **kernel_opts).view(1, 1, 3, 3)
        if c == 3:
            source = images[:, 0, :, :].unsqueeze(1)
        else:
            source = images
        gradient_x = F.conv2d(source, a)
        gradient_y = F.conv2d(source, b)
        edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
        edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
        thetas = torch.atan2(gradient_y, gradient_x)
        thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)

        return edges, thetas

    def forward(self, inputs, targets, images, masks=None):
        """Return the ranking loss averaged over the batch (float32 scalar).

        Args:
            inputs: predictions with the same number of elements as targets.
            targets: ground truth, unpacked as ``(n, c, h, w)``.
            images: images used to locate edges (see ``getEdge``).
            masks: optional validity mask; defaults to ``targets > mask_value``.
        """
        if masks is None:
            masks = targets > self.mask_value
        # Uncomment to use the multi-scale gradient matching term:
        # regularization_loss = self.regularization_loss(inputs.squeeze(1), targets.squeeze(1), masks.squeeze(1))
        # find edges from RGB
        edges_img, thetas_img = self.getEdge(images)

        #=============================
        n, c, h, w = targets.size()
        # Flatten every map to (n, h*w) float64. ``contiguous()`` makes the
        # view valid for any batch size (previously only done for n == 1).
        inputs = inputs.contiguous().view(n, -1).double()
        targets = targets.contiguous().view(n, -1).double()
        masks = masks.contiguous().view(n, -1).double()
        edges_img = edges_img.contiguous().view(n, -1).double()
        thetas_img = thetas_img.contiguous().view(n, -1).double()

        # initialization
        loss = torch.zeros(1, dtype=torch.float64, device=inputs.device)

        for i in range(n):
            # Edge-Guided sampling
            inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num = edgeGuidedSampling(inputs[i, :], targets[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w)
            # Random Sampling
            random_sample_num = sample_num
            random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i, :], targets[i, :], masks[i, :], self.mask_value, random_sample_num)

            # Combine EGS + RS
            inputs_A = torch.cat((inputs_A, random_inputs_A), 0)
            inputs_B = torch.cat((inputs_B, random_inputs_B), 0)
            targets_A = torch.cat((targets_A, random_targets_A), 0)
            targets_B = torch.cat((targets_B, random_targets_B), 0)
            masks_A = torch.cat((masks_A, random_masks_A), 0)
            masks_B = torch.cat((masks_B, random_masks_B), 0)

            # GT ordinal relationship: label 1 if ratio >= 1+sigma, -1 if
            # ratio <= 1/(1+sigma), and "equal" (label 0) in between
            target_ratio = torch.div(targets_A + 1e-6, targets_B + 1e-6)
            mask_eq = target_ratio.lt(1.0 + self.sigma) & target_ratio.gt(1.0 / (1.0 + self.sigma))
            labels = torch.zeros_like(target_ratio)
            labels[target_ratio.ge(1.0 + self.sigma)] = 1
            labels[target_ratio.le(1.0 / (1.0 + self.sigma))] = -1

            # consider forward-backward consistency checking, i.e, only compute losses of point pairs with valid GT
            consistency_mask = masks_A * masks_B

            equal_loss = (inputs_A - inputs_B).pow(2) * mask_eq.double() * consistency_mask
            # softplus(x) == log(1 + exp(x)) without overflowing to inf for large x
            unequal_loss = F.softplus((-inputs_A + inputs_B) * labels) * (~mask_eq).double() * consistency_mask

            # Add "+ 0.2 * regularization_loss.double()" here to use the multi-scale gradient matching loss
            loss = loss + self.alpha * equal_loss.mean() + 1.0 * unequal_loss.mean()

        return loss[0].float() / n
198 |
--------------------------------------------------------------------------------