├── LICENCE.md ├── README.md ├── demo.py ├── demo.sh ├── examples ├── 000000000724.jpg ├── 000000404922.jpg ├── 00022_00197_outdoor_300_050.png ├── 0SpJOOTH7R4_144577767_image.jpg ├── 0SpJOOTH7R4_215215000_image.jpg └── frame_0017.png ├── models ├── DepthNet.py ├── networks.py ├── resnet.py └── syncbn │ ├── LICENSE │ ├── README.md │ ├── make_ext.sh │ ├── modules │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ ├── functional │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── syncbn.cpython-37.pyc │ │ ├── _syncbn │ │ │ ├── __init__.py │ │ │ ├── __init__.pyc │ │ │ ├── __pycache__ │ │ │ │ └── __init__.cpython-37.pyc │ │ │ ├── _ext │ │ │ │ ├── __init__.py │ │ │ │ ├── __init__.pyc │ │ │ │ ├── __pycache__ │ │ │ │ │ └── __init__.cpython-37.pyc │ │ │ │ └── syncbn │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __init__.pyc │ │ │ │ │ ├── __pycache__ │ │ │ │ │ └── __init__.cpython-37.pyc │ │ │ │ │ └── _syncbn.so │ │ │ ├── build.py │ │ │ └── src │ │ │ │ ├── common.h │ │ │ │ ├── syncbn.cpp │ │ │ │ ├── syncbn.cu │ │ │ │ ├── syncbn.cu.h │ │ │ │ ├── syncbn.cu.o │ │ │ │ └── syncbn.h │ │ ├── syncbn.py │ │ └── syncbn.pyc │ └── nn │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── syncbn.cpython-37.pyc │ │ ├── syncbn.py │ │ └── syncbn.pyc │ ├── requirements.txt │ └── test.py └── ranking_loss.py /LICENCE.md: -------------------------------------------------------------------------------- 1 | 2 | This software is for non-commercial purposes 3 | 4 | Copyright (c) 2020 Ke Xian All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | 8 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
9 | 10 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Structure-Guided Ranking Loss for Single Image Depth Prediction 2 | This repository contains a PyTorch implementation of our CVPR 2020 paper "Structure-Guided Ranking Loss for Single Image Depth Prediction". 3 | [Project Page](https://KexianHust.github.io/Structure-Guided-Ranking-Loss/) 4 | ![Teaser Image](https://KexianHust.github.io/Structure-Guided-Ranking-Loss/teaser.png) 5 | 6 | ## Changelog 7 | * [Jun. 2020] Initial release 8 | 9 | ## To do 10 | - [ ] Mixed-data training 11 | 12 | ## Prerequisites 13 | * PyTorch >= 0.4.1 14 | * CUDA >= 8.0 15 | * Python >= 2.7 16 | * glob, matplotlib 17 | * You need to compile the syncbn module in models/syncbn. 
Note that the path to the syncbn module may need to be adjusted in DepthNet.py, resnet.py and networks.py (see their sys.path.append calls). 18 | * Download the [model.pth.tar](https://drive.google.com/file/d/1p8c8-nUTNry5usQmGdTC2TrwWrp3dQ0y/view?usp=sharing) and place it in the directory you run demo.py from. 19 | 20 | ## Inference 21 | ```bash 22 | # Before running, you should set the CUDA_VISIBLE_DEVICES in demo.sh 23 | bash demo.sh 24 | 25 | ``` 26 | 27 | If you find our work useful in your research, please consider citing the paper. 28 | 29 | ``` 30 | @InProceedings{Xian_2020_CVPR, 31 | author = {Xian, Ke and Zhang, Jianming and Wang, Oliver and Mai, Long and Lin, Zhe and Cao, Zhiguo}, 32 | title = {Structure-Guided Ranking Loss for Single Image Depth Prediction}, 33 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 34 | month = {June}, 35 | year = {2020} 36 | } 37 | ``` 38 | 39 | ## Dataset 40 | Our [HRWSI](https://drive.google.com/file/d/1OVOx6x-B0Cs-m2z_-7ZxSgRFHz_VBvDd/view?usp=sharing) dataset is for research only! Some researchers may be interested in the stereo data, so we provide the right views [here](https://drive.google.com/file/d/1HzEB7yQI05Q21dP9rRjnyMoEmvCckAQp/view?usp=sharing). Please let me know if you have any questions. 
41 | 42 | ## Lisence 43 | Research only 44 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # demo 5 | 6 | """ 7 | Author: Ke Xian 8 | Email: kexian@hust.edu.cn 9 | Create_Date: 2019/05/21 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torchvision.transforms as transforms 15 | from torch.utils.data import DataLoader 16 | torch.backends.cudnn.deterministic = True 17 | torch.manual_seed(123) 18 | 19 | import os, argparse, sys 20 | import numpy as np 21 | import glob 22 | import matplotlib.pyplot as plt 23 | plt.switch_backend('agg') 24 | import warnings 25 | warnings.filterwarnings("ignore") 26 | from PIL import Image 27 | 28 | sys.path.append('models') 29 | import DepthNet 30 | 31 | # ======================= 32 | # demo 33 | # ======================= 34 | def demo(net, args): 35 | data_dir = args.data_dir 36 | img_transform = transforms.Compose([ 37 | transforms.ToTensor(), 38 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 39 | ]) 40 | 41 | for im in os.listdir(data_dir): 42 | im_dir = os.path.join(data_dir, im) 43 | print('Processing img: {}'.format(im_dir)) 44 | 45 | # Read image 46 | img = Image.open(im_dir).convert('RGB') 47 | ori_width, ori_height = img.size 48 | int_width = args.img_size[0] 49 | int_height = args.img_size[1] 50 | img = img.resize((int_width, int_height), Image.ANTIALIAS) 51 | tensor_img = img_transform(img) 52 | 53 | # forward 54 | input_img = torch.autograd.Variable(tensor_img.cuda().unsqueeze(0), volatile=True) 55 | output = net(input_img) 56 | 57 | # Normalization and save results 58 | depth = output.squeeze().cpu().data.numpy() 59 | min_d, max_d = depth.min(), depth.max() 60 | depth_norm = (depth - min_d) / (max_d - min_d) * 255 61 | depth_norm = depth_norm.astype(np.uint8) 62 | image_pil = Image.fromarray(depth_norm) 
63 | 64 | output_dir = os.path.join(args.result_dir, im) 65 | image_pil = image_pil.resize((ori_width, ori_height), Image.BILINEAR) 66 | plt.imsave(output_dir, np.asarray(image_pil), cmap='inferno') 67 | 68 | 69 | if __name__ == '__main__': 70 | 71 | parser = argparse.ArgumentParser(description='MRDP Testing/Evaluation') 72 | parser.add_argument('--img_size', default=[448, 448], type=list, help='Image size of network input') 73 | parser.add_argument('--data_dir', default='examples', type=str, help='Data path') 74 | parser.add_argument('--result_dir', default='demo_results', type=str, help='Directory for saving results, default: demo_results') 75 | parser.add_argument('--gpu_id', default=0, type=int, help='GPU id, default:0') 76 | args = parser.parse_args() 77 | 78 | args.checkpoint = 'model.pth.tar' 79 | 80 | if not os.path.exists(args.result_dir): 81 | os.makedirs(args.result_dir) 82 | 83 | gpu_id = args.gpu_id 84 | torch.cuda.device(gpu_id) 85 | 86 | net = DepthNet.DepthNet() 87 | net = torch.nn.DataParallel(net, device_ids=[0]).cuda() 88 | checkpoint = torch.load(args.checkpoint) 89 | net.load_state_dict(checkpoint['state_dict']) 90 | net.eval() 91 | 92 | print('Begin to test ...') 93 | with torch.no_grad(): 94 | demo(net, args) 95 | print('Finished!') 96 | -------------------------------------------------------------------------------- /demo.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python demo.py 2 | -------------------------------------------------------------------------------- /examples/000000000724.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/000000000724.jpg -------------------------------------------------------------------------------- /examples/000000404922.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/000000404922.jpg -------------------------------------------------------------------------------- /examples/00022_00197_outdoor_300_050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/00022_00197_outdoor_300_050.png -------------------------------------------------------------------------------- /examples/0SpJOOTH7R4_144577767_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/0SpJOOTH7R4_144577767_image.jpg -------------------------------------------------------------------------------- /examples/0SpJOOTH7R4_215215000_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/0SpJOOTH7R4_215215000_image.jpg -------------------------------------------------------------------------------- /examples/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/examples/frame_0017.png -------------------------------------------------------------------------------- /models/DepthNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # coding: utf-8 3 | 4 | ''' 5 | Author: Ke Xian 6 | Email: kexian@hust.edu.cn 7 | Date: 2019/04/09 8 | ''' 9 | 10 | import torch 11 | import 
torchvision 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.nn.init as init 15 | 16 | import sys 17 | sys.path.append('/data0/kexian/Code/kxian_Adobe/MPO_edgeGuidedRanking/models/syncbn') 18 | from modules import nn as NN 19 | 20 | import resnet 21 | 22 | from networks import * 23 | 24 | class Decoder(nn.Module): 25 | def __init__(self, inchannels = [256, 512, 1024, 2048], midchannels = [256, 256, 256, 512], upfactors = [2,2,2,2], outchannels = 1): 26 | super(Decoder, self).__init__() 27 | self.inchannels = inchannels 28 | self.midchannels = midchannels 29 | self.upfactors = upfactors 30 | self.outchannels = outchannels 31 | 32 | self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3]) 33 | self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True) 34 | self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True) 35 | 36 | self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2]) 37 | self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1]) 38 | self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0]) 39 | 40 | self.outconv = AO(inchannels=self.inchannels[0], outchannels=self.outchannels, upfactor=2) 41 | 42 | self._init_params() 43 | 44 | def _init_params(self): 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | #init.kaiming_normal_(m.weight, mode='fan_out') 48 | init.normal_(m.weight, std=0.01) 49 | #init.xavier_normal_(m.weight) 50 | if m.bias is not None: 51 | init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.ConvTranspose2d): 53 | #init.kaiming_normal_(m.weight, mode='fan_out') 54 | init.normal_(m.weight, std=0.01) 55 
| #init.xavier_normal_(m.weight) 56 | if m.bias is not None: 57 | init.constant_(m.bias, 0) 58 | elif isinstance(m, NN.BatchNorm2d): #NN.BatchNorm2d 59 | init.constant_(m.weight, 1) 60 | init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.Linear): 62 | init.normal_(m.weight, std=0.01) 63 | if m.bias is not None: 64 | init.constant_(m.bias, 0) 65 | 66 | def forward(self, features): 67 | _,_,h,w = features[3].size() 68 | x = self.conv(features[3]) 69 | x = self.conv1(x) 70 | x = self.upsample(x) 71 | 72 | x = self.ffm2(features[2], x) 73 | x = self.ffm1(features[1], x) 74 | x = self.ffm0(features[0], x) 75 | 76 | #----------------------------------------- 77 | x = self.outconv(x) 78 | 79 | return x 80 | 81 | class DepthNet(nn.Module): 82 | __factory = { 83 | 18: resnet.resnet18, 84 | 34: resnet.resnet34, 85 | 50: resnet.resnet50, 86 | 101: resnet.resnet101, 87 | 152: resnet.resnet152 88 | } 89 | def __init__(self, 90 | backbone='resnet', 91 | depth=50, 92 | pretrained=True, 93 | inchannels=[256, 512, 1024, 2048], 94 | midchannels=[256, 256, 256, 512], 95 | upfactors=[2, 2, 2, 2], 96 | outchannels=1): 97 | super(DepthNet, self).__init__() 98 | self.backbone = backbone 99 | self.depth = depth 100 | self.pretrained = pretrained 101 | self.inchannels = inchannels 102 | self.midchannels = midchannels 103 | self.upfactors = upfactors 104 | self.outchannels = outchannels 105 | 106 | # Build model 107 | if self.depth not in DepthNet.__factory: 108 | raise KeyError("Unsupported depth:", self.depth) 109 | self.encoder = DepthNet.__factory[depth](pretrained=pretrained) 110 | 111 | self.decoder = Decoder(inchannels=self.inchannels, midchannels=self.midchannels, upfactors=self.upfactors, outchannels=self.outchannels) 112 | 113 | def forward(self, x): 114 | x = self.encoder(x) # 1/4, 1/8, 1/16, 1/32 115 | x = self.decoder(x) 116 | 117 | return x 118 | 119 | if __name__ == '__main__': 120 | net = DepthNet(depth=50, pretrained=True) 121 | print(net) 122 | inputs = 
torch.ones(4,3,128,128) 123 | out = net(inputs) 124 | print(out.size()) 125 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Author: Ke Xian 5 | Email: kexian@hust.edu.cn 6 | Date: 2019/04/09 7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.init as init 12 | import sys 13 | sys.path.append('/data0/kexian/Code/kxian_Adobe/MPO_edgeGuidedRanking/models/syncbn') 14 | import modules.nn as NN 15 | 16 | # ============================================================================================================== 17 | 18 | class FTB(nn.Module): 19 | def __init__(self, inchannels, midchannels=512): 20 | super(FTB, self).__init__() 21 | self.in1 = inchannels 22 | self.mid = midchannels 23 | 24 | self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True) 25 | # NN.BatchNorm2d 26 | self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\ 27 | nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\ 28 | NN.BatchNorm2d(num_features=self.mid),\ 29 | nn.ReLU(inplace=True),\ 30 | nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True)) 31 | self.relu = nn.ReLU(inplace=True) 32 | 33 | self.init_params() 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | x = x + self.conv_branch(x) 38 | x = self.relu(x) 39 | 40 | return x 41 | 42 | def init_params(self): 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | #init.kaiming_normal_(m.weight, mode='fan_out') 46 | init.normal_(m.weight, std=0.01) 47 | # init.xavier_normal_(m.weight) 48 | if m.bias is not None: 49 | init.constant_(m.bias, 0) 50 | elif isinstance(m, nn.ConvTranspose2d): 51 | #init.kaiming_normal_(m.weight, mode='fan_out') 
52 | init.normal_(m.weight, std=0.01) 53 | # init.xavier_normal_(m.weight) 54 | if m.bias is not None: 55 | init.constant_(m.bias, 0) 56 | elif isinstance(m, NN.BatchNorm2d): #NN.BatchNorm2d 57 | init.constant_(m.weight, 1) 58 | init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.Linear): 60 | init.normal_(m.weight, std=0.01) 61 | if m.bias is not None: 62 | init.constant_(m.bias, 0) 63 | 64 | 65 | class FFM(nn.Module): 66 | def __init__(self, inchannels, midchannels, outchannels, upfactor=2): 67 | super(FFM, self).__init__() 68 | self.inchannels = inchannels 69 | self.midchannels = midchannels 70 | self.outchannels = outchannels 71 | self.upfactor = upfactor 72 | 73 | self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels) 74 | self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) 75 | 76 | self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) 77 | 78 | self.init_params() 79 | 80 | def forward(self, low_x, high_x): 81 | x = self.ftb1(low_x) 82 | x = x + high_x 83 | x = self.ftb2(x) 84 | x = self.upsample(x) 85 | 86 | return x 87 | 88 | def init_params(self): 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | #init.kaiming_normal_(m.weight, mode='fan_out') 92 | init.normal_(m.weight, std=0.01) 93 | #init.xavier_normal_(m.weight) 94 | if m.bias is not None: 95 | init.constant_(m.bias, 0) 96 | elif isinstance(m, nn.ConvTranspose2d): 97 | #init.kaiming_normal_(m.weight, mode='fan_out') 98 | init.normal_(m.weight, std=0.01) 99 | #init.xavier_normal_(m.weight) 100 | if m.bias is not None: 101 | init.constant_(m.bias, 0) 102 | elif isinstance(m, NN.BatchNorm2d): #NN.Batchnorm2d 103 | init.constant_(m.weight, 1) 104 | init.constant_(m.bias, 0) 105 | elif isinstance(m, nn.Linear): 106 | init.normal_(m.weight, std=0.01) 107 | if m.bias is not None: 108 | init.constant_(m.bias, 0) 109 | 110 | 111 | class AO(nn.Module): 112 | # Adaptive output module 113 | def __init__(self, 
inchannels, outchannels, upfactor=2): 114 | super(AO, self).__init__() 115 | self.inchannels = inchannels 116 | self.outchannels = outchannels 117 | self.upfactor = upfactor 118 | 119 | self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels/2, kernel_size=3, padding=1, stride=1, bias=True),\ 120 | NN.BatchNorm2d(num_features=self.inchannels/2),\ 121 | nn.ReLU(inplace=True),\ 122 | nn.Conv2d(in_channels=self.inchannels/2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\ 123 | nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)) 124 | 125 | self.init_params() 126 | 127 | def forward(self, x): 128 | x = self.adapt_conv(x) 129 | return x 130 | 131 | def init_params(self): 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | #init.kaiming_normal_(m.weight, mode='fan_out') 135 | init.normal_(m.weight, std=0.01) 136 | #init.xavier_normal_(m.weight) 137 | if m.bias is not None: 138 | init.constant_(m.bias, 0) 139 | elif isinstance(m, nn.ConvTranspose2d): 140 | #init.kaiming_normal_(m.weight, mode='fan_out') 141 | init.normal_(m.weight, std=0.01) 142 | #init.xavier_normal_(m.weight) 143 | if m.bias is not None: 144 | init.constant_(m.bias, 0) 145 | elif isinstance(m, NN.BatchNorm2d): #NN.Batchnorm2d 146 | init.constant_(m.weight, 1) 147 | init.constant_(m.bias, 0) 148 | elif isinstance(m, nn.Linear): 149 | init.normal_(m.weight, std=0.01) 150 | if m.bias is not None: 151 | init.constant_(m.bias, 0) 152 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torchvision 5 | 6 | import sys 7 | sys.path.append('/data0/kexian/Code/kxian_Adobe/MPO_edgeGuidedRanking/models/syncbn') 8 | from modules import nn as NN 9 | 10 | 11 | __all__ = 
['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 72 | 
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 73 | self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 64 105 | super(ResNet, self).__init__() 106 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 107 | bias=False) 108 | self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d 109 | self.relu = nn.ReLU(inplace=True) 110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | #self.avgpool = nn.AvgPool2d(7, stride=1) 116 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1) 123 | nn.init.constant_(m.bias, 0) 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | 
kernel_size=1, stride=stride, bias=False), 131 | NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for i in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | features = [] 144 | 145 | x = self.conv1(x) 146 | x = self.bn1(x) 147 | x = self.relu(x) 148 | x = self.maxpool(x) 149 | 150 | x = self.layer1(x) 151 | features.append(x) 152 | x = self.layer2(x) 153 | features.append(x) 154 | x = self.layer3(x) 155 | features.append(x) 156 | x = self.layer4(x) 157 | features.append(x) 158 | 159 | return features 160 | 161 | 162 | def resnet18(pretrained=True, **kwargs): 163 | """Constructs a ResNet-18 model. 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 168 | if pretrained: 169 | pretrained_model = torchvision.models.resnet18(pretrained=True) 170 | pretrained_dict = pretrained_model.state_dict() 171 | model_dict = model.state_dict() 172 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 173 | model_dict.update(pretrained_dict) 174 | model.load_state_dict(model_dict) 175 | 176 | return model 177 | 178 | 179 | def resnet34(pretrained=True, **kwargs): 180 | """Constructs a ResNet-34 model. 
181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | pretrained_model = torchvision.models.resnet34(pretrained=True) 187 | pretrained_dict = pretrained_model.state_dict() 188 | model_dict = model.state_dict() 189 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 190 | model_dict.update(pretrained_dict) 191 | model.load_state_dict(model_dict) 192 | 193 | return model 194 | 195 | 196 | def resnet50(pretrained=True, **kwargs): 197 | """Constructs a ResNet-50 model. 198 | Args: 199 | pretrained (bool): If True, returns a model pre-trained on ImageNet 200 | """ 201 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 202 | if pretrained: 203 | pretrained_model = torchvision.models.resnet50(pretrained=True) 204 | pretrained_dict = pretrained_model.state_dict() 205 | model_dict = model.state_dict() 206 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 207 | model_dict.update(pretrained_dict) 208 | model.load_state_dict(model_dict) 209 | 210 | return model 211 | 212 | 213 | def resnet101(pretrained=True, **kwargs): 214 | """Constructs a ResNet-101 model. 215 | Args: 216 | pretrained (bool): If True, returns a model pre-trained on ImageNet 217 | """ 218 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 219 | if pretrained: 220 | pretrained_model = torchvision.models.resnet101(pretrained=True) 221 | pretrained_dict = pretrained_model.state_dict() 222 | model_dict = model.state_dict() 223 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 224 | model_dict.update(pretrained_dict) 225 | model.load_state_dict(model_dict) 226 | 227 | return model 228 | 229 | 230 | def resnet152(pretrained=True, **kwargs): 231 | """Constructs a ResNet-152 model. 
232 | Args: 233 | pretrained (bool): If True, returns a model pre-trained on ImageNet 234 | """ 235 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 236 | if pretrained: 237 | pretrained_model = torchvision.models.resnet152(pretrained=True) 238 | pretrained_dict = pretrained_model.state_dict() 239 | model_dict = model.state_dict() 240 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 241 | model_dict.update(pretrained_dict) 242 | model.load_state_dict(model_dict) 243 | 244 | return model 245 | -------------------------------------------------------------------------------- /models/syncbn/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tamaki Kojima 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/syncbn/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-syncbn 2 | 3 | Tamaki Kojima (tamakoji@gmail.com) 4 | 5 | ## Overview 6 | This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global statistics across GPUs instead of per-GPU (local) statistics. SyncBN becomes important when input images are large and multiple GPUs must be used to increase the mini-batch size for training. 7 | 8 | The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn) 9 | 10 | ## Remarks 11 | - Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel` 12 | - Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module implementation, since it does not mark tensors for in-place operation 13 | - You can plug it into arbitrary modules written in PyTorch to enable synchronized BatchNorm 14 | - Backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d` 15 | 16 | ## Requirements 17 | For PyTorch, please refer to https://pytorch.org/ 18 | 19 | NOTE: The code is tested only with PyTorch v0.4.0, CUDA 9.1.85/cuDNN 7.1.4 on Ubuntu 16.04 20 | 21 | (It can also be compiled and run on the Jetson TX2, but will not work as multi-GPU synchronized BN.) 22 | 23 | To install all dependencies using pip, run: 24 | 25 | ``` 26 | pip install -U -r requirements.txt 27 | ``` 28 | 29 | ## Build 30 | 31 | Use `make_ext.sh` to build the extension.
for example: 32 | ``` 33 | PYTHON_CMD=python3 ./make_ext.sh 34 | ``` 35 | 36 | ## Usage 37 | 38 | Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d` 39 | 40 | ``` 41 | import torch 42 | from modules import nn as NN 43 | num_gpu = torch.cuda.device_count() 44 | model = nn.Sequential( 45 | nn.Conv2d(3, 3, 1, 1, bias=False), 46 | NN.BatchNorm2d(3), 47 | nn.ReLU(inplace=True), 48 | nn.Conv2d(3, 3, 1, 1, bias=False), 49 | NN.BatchNorm2d(3), 50 | ).cuda() 51 | model = nn.DataParallel(model, device_ids=range(num_gpu)) 52 | x = torch.rand(num_gpu, 3, 2, 2).cuda() 53 | z = model(x) 54 | ``` 55 | 56 | ## Math 57 | 58 | ### Forward 59 | 1. compute in each gpu 60 | 2. gather all from workers to master and compute where 61 | 62 | 63 | 64 | and 65 | 66 | 67 | 68 | and then above global stats to be shared to all gpus, update running_mean and running_var by moving average using global stats. 69 | 70 | 3. forward batchnorm using global stats by 71 | 72 | 73 | 74 | and then 75 | 76 | 77 | 78 | where is weight parameter and is bias parameter. 79 | 80 | 4. save for backward 81 | 82 | ### Backward 83 | 84 | 1. Restore saved 85 | 86 | 2. Compute below sums on each gpu 87 | 88 | 89 | 90 | and 91 | 92 | 93 | 94 | where 95 | 96 | then gather them at master node to sum up global, and normalize with N where N is total number of elements for each channels. Global sums are then shared among all gpus. 97 | 98 | 3. compute gradients using global stats 99 | 100 | 101 | 102 | where 103 | 104 | 105 | 106 | and 107 | 108 | 109 | 110 | and finally, 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | Note that in the implementation, normalization with N is performed at step (2) and above equation and implementation is not exactly the same, but mathematically is same. 
119 | 120 | You can go deeper on above explanation at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/) -------------------------------------------------------------------------------- /models/syncbn/make_ext.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e  # abort on the first failing step (cd, nvcc) instead of running build.py without src/syncbn.cu.o 3 | PYTHON_CMD=${PYTHON_CMD:=python} 4 | CUDA_PATH=${CUDA_PATH:=/usr/local/cuda-8.0} 5 | CUDA_INCLUDE_DIR=${CUDA_INCLUDE_DIR:=$CUDA_PATH/include} 6 | GENCODE="-gencode arch=compute_61,code=sm_61 \ 7 | -gencode arch=compute_52,code=sm_52 \ 8 | -gencode arch=compute_52,code=compute_52" 9 | NVCCOPT="-std=c++11 -x cu --expt-extended-lambda -O3 -Xcompiler -fPIC" 10 | # Run from models/syncbn. PYTHON_CMD, CUDA_PATH and CUDA_INCLUDE_DIR may be overridden from the environment. 11 | ROOTDIR="$PWD" 12 | echo "========= Build BatchNorm2dSync =========" 13 | if [ -z "$1" ]; then TORCH=$($PYTHON_CMD -c "import os; import torch; print(os.path.dirname(torch.__file__))"); else TORCH="$1"; fi  # NOTE: TORCH is not used by the build steps below 14 | cd modules/functional/_syncbn/src 15 | "$CUDA_PATH"/bin/nvcc -c -o syncbn.cu.o syncbn.cu $NVCCOPT $GENCODE -I "$CUDA_INCLUDE_DIR" 16 | cd ../ 17 | $PYTHON_CMD build.py 18 | cd "$ROOTDIR" 19 | 20 | # END 21 | echo "========= Build Complete =========" 22 | -------------------------------------------------------------------------------- /models/syncbn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/__init__.py -------------------------------------------------------------------------------- /models/syncbn/modules/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/__init__.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/__pycache__/__init__.cpython-37.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import batchnorm2d_sync 2 | -------------------------------------------------------------------------------- /models/syncbn/modules/functional/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/__init__.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/__pycache__/syncbn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/__pycache__/syncbn.cpython-37.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/__init__.py -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/__init__.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/__init__.py -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/_ext/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/__init__.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/_ext/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.py: -------------------------------------------------------------------------------- 1 | # Loader for the compiled _syncbn extension: re-exports every symbol of its `lib` at module level. 2 | from torch.utils.ffi import _wrap_function 3 | from ._syncbn import lib as _lib, ffi as _ffi 4 | # NOTE(review): torch.utils.ffi only exists in old PyTorch releases (the syncbn README says it was tested with v0.4.0). 5 | __all__ = [] 6 | def _import_symbols(locals):  # `locals` receives the module namespace dict (and shadows the builtin of that name) 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn):  # callables are wrapped with torch.utils.ffi._wrap_function; syncbn.py calls them with tensors 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/_ext/syncbn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/syncbn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/_ext/syncbn/_syncbn.so: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/_ext/syncbn/_syncbn.so -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.ffi import create_extension 3 | 4 | sources = ['src/syncbn.cpp'] 5 | headers = ['src/syncbn.h'] 6 | extra_objects = ['src/syncbn.cu.o'] 7 | with_cuda = True 8 | 9 | this_file = os.path.dirname(os.path.realpath(__file__)) 10 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 11 | 12 | ffi = create_extension( 13 | '_ext.syncbn', 14 | headers=headers, 15 | sources=sources, 16 | relative_to=__file__, 17 | with_cuda=with_cuda, 18 | extra_objects=extra_objects, 19 | extra_compile_args=["-std=c++11"] 20 | ) 21 | 22 | if __name__ == '__main__': 23 | ffi.build() 24 | -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/src/common.h: -------------------------------------------------------------------------------- 1 | #ifndef __COMMON__ 2 | #define __COMMON__ 3 | #include 4 | 5 | /* 6 | * General settings 7 | */ 8 | const int WARP_SIZE = 32; 9 | const int MAX_BLOCK_SIZE = 512; 10 | 11 | /* 12 | * Utility functions 13 | */ 14 | template 15 | __device__ __forceinline__ T WARP_SHFL_XOR( 16 | T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { 17 | #if CUDART_VERSION >= 9000 18 | return __shfl_xor_sync(mask, value, laneMask, width); 19 | #else 20 | return __shfl_xor(value, laneMask, width); 21 | #endif 22 | } 23 | // Index of the most significant set bit of val 24 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 25 | // Smallest of {32, 64, 128, 256, MAX_BLOCK_SIZE} that is >= nElem; MAX_BLOCK_SIZE if nElem is larger 26 | static int getNumThreads(int nElem) { 27 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 28 | for (int i = 0; i != 5; ++i) {
29 | if (nElem <= threadSizes[i]) { 30 | return threadSizes[i]; 31 | } 32 | } 33 | return MAX_BLOCK_SIZE; 34 | } 35 | 36 | 37 | #endif -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/src/syncbn.cpp: -------------------------------------------------------------------------------- 1 | // All functions assume that input and output tensors are already initialized 2 | // and have the correct dimensions 3 | #include 4 | 5 | extern THCState *state; 6 | 7 | void get_sizes(const THCudaTensor *t, int *N, int *C, int *S) { 8 | // Get sizes: N = size(0), C = size(1), S = product of the remaining dims (1 for a 2-D tensor) 9 | *S = 1; 10 | *N = THCudaTensor_size(state, t, 0); 11 | *C = THCudaTensor_size(state, t, 1); 12 | if (THCudaTensor_nDimension(state, t) > 2) { 13 | for (int i = 2; i < THCudaTensor_nDimension(state, t); ++i) { 14 | *S *= THCudaTensor_size(state, t, i); 15 | } 16 | } 17 | } 18 | 19 | // Forward definition of implementation functions 20 | extern "C" { 21 | int _syncbn_sum_sqsum_cuda(int N, int C, int S, 22 | const float *x, float *sum, float *sqsum, 23 | cudaStream_t stream); 24 | int _syncbn_forward_cuda( 25 | int N, int C, int S, float *z, const float *x, 26 | const float *gamma, const float *beta, 27 | const float *mean, const float *var, float eps, cudaStream_t stream); 28 | int _syncbn_backward_xhat_cuda( 29 | int N, int C, int S, const float *dz, const float *x, 30 | const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat, 31 | float eps, cudaStream_t stream); 32 | int _syncbn_backward_cuda( 33 | int N, int C, int S, const float *dz, const float *x, 34 | const float *gamma, const float *beta, 35 | const float *mean, const float *var, 36 | const float *sum_dz, const float *sum_dz_xhat, 37 | float *dx, float *dgamma, float *dbeta, 38 | float eps, cudaStream_t stream); 39 | } 40 | // Each wrapper below returns the launcher's status: 1 on success, 0 if cudaGetLastError reported an error 41 | extern "C" int syncbn_sum_sqsum_cuda( 42 | const THCudaTensor *x, THCudaTensor *sum, THCudaTensor *sqsum) { 43 | cudaStream_t stream =
THCState_getCurrentStream(state); 44 | 45 | int S, N, C; 46 | get_sizes(x, &N, &C, &S); 47 | 48 | // Get pointers 49 | const float *x_data = THCudaTensor_data(state, x); 50 | float *sum_data = THCudaTensor_data(state, sum); 51 | float *sqsum_data = THCudaTensor_data(state, sqsum); 52 | 53 | return _syncbn_sum_sqsum_cuda(N, C, S, x_data, sum_data, sqsum_data, stream); 54 | } 55 | 56 | extern "C" int syncbn_forward_cuda( 57 | THCudaTensor *z, const THCudaTensor *x, 58 | const THCudaTensor *gamma, const THCudaTensor *beta, 59 | const THCudaTensor *mean, const THCudaTensor *var, float eps){ 60 | cudaStream_t stream = THCState_getCurrentStream(state); 61 | 62 | int S, N, C; 63 | get_sizes(x, &N, &C, &S); 64 | 65 | // Get pointers (an empty gamma/beta tensor is passed on as a null pointer) 66 | float *z_data = THCudaTensor_data(state, z); 67 | const float *x_data = THCudaTensor_data(state, x); 68 | const float *gamma_data = THCudaTensor_nDimension(state, gamma) != 0 ? 69 | THCudaTensor_data(state, gamma) : 0; 70 | const float *beta_data = THCudaTensor_nDimension(state, beta) != 0 ?
71 | THCudaTensor_data(state, beta) : 0; 72 | const float *mean_data = THCudaTensor_data(state, mean); 73 | const float *var_data = THCudaTensor_data(state, var); 74 | 75 | return _syncbn_forward_cuda( 76 | N, C, S, z_data, x_data, gamma_data, beta_data, 77 | mean_data, var_data, eps, stream); 78 | 79 | } 80 | 81 | extern "C" int syncbn_backward_xhat_cuda( 82 | const THCudaTensor *dz, const THCudaTensor *x, 83 | const THCudaTensor *mean, const THCudaTensor *var, 84 | THCudaTensor *sum_dz, THCudaTensor *sum_dz_xhat, float eps) { 85 | cudaStream_t stream = THCState_getCurrentStream(state); 86 | 87 | int S, N, C; 88 | get_sizes(dz, &N, &C, &S); 89 | 90 | // Get pointers 91 | const float *dz_data = THCudaTensor_data(state, dz); 92 | const float *x_data = THCudaTensor_data(state, x); 93 | const float *mean_data = THCudaTensor_data(state, mean); 94 | const float *var_data = THCudaTensor_data(state, var); 95 | float *sum_dz_data = THCudaTensor_data(state, sum_dz); 96 | float *sum_dz_xhat_data = THCudaTensor_data(state, sum_dz_xhat); 97 | 98 | return _syncbn_backward_xhat_cuda( 99 | N, C, S, dz_data, x_data, mean_data, var_data, 100 | sum_dz_data, sum_dz_xhat_data, eps, stream); 101 | 102 | } // NOTE: 'backard' (sic) is the exported name also used in syncbn.h and syncbn.py; rename all of them together 103 | extern "C" int syncbn_backard_cuda( 104 | const THCudaTensor *dz, const THCudaTensor *x, 105 | const THCudaTensor *gamma, const THCudaTensor *beta, 106 | const THCudaTensor *mean, const THCudaTensor *var, 107 | const THCudaTensor *sum_dz, const THCudaTensor *sum_dz_xhat, 108 | THCudaTensor *dx, THCudaTensor *dgamma, THCudaTensor *dbeta, float eps) { 109 | cudaStream_t stream = THCState_getCurrentStream(state); 110 | 111 | int S, N, C; 112 | get_sizes(dz, &N, &C, &S); 113 | 114 | // Get pointers (empty optional tensors are passed on as null pointers) 115 | const float *dz_data = THCudaTensor_data(state, dz); 116 | const float *x_data = THCudaTensor_data(state, x); 117 | const float *gamma_data = THCudaTensor_nDimension(state, gamma) != 0 ?
118 | THCudaTensor_data(state, gamma) : 0; 119 | const float *beta_data = THCudaTensor_nDimension(state, beta) != 0 ? 120 | THCudaTensor_data(state, beta) : 0; 121 | const float *mean_data = THCudaTensor_data(state, mean); 122 | const float *var_data = THCudaTensor_data(state, var); 123 | const float *sum_dz_data = THCudaTensor_data(state, sum_dz); 124 | const float *sum_dz_xhat_data = THCudaTensor_data(state, sum_dz_xhat); 125 | float *dx_data = THCudaTensor_nDimension(state, dx) != 0 ? 126 | THCudaTensor_data(state, dx) : 0; 127 | float *dgamma_data = THCudaTensor_nDimension(state, dgamma) != 0 ? 128 | THCudaTensor_data(state, dgamma) : 0; 129 | float *dbeta_data = THCudaTensor_nDimension(state, dbeta) != 0 ? 130 | THCudaTensor_data(state, dbeta) : 0; 131 | 132 | return _syncbn_backward_cuda( 133 | N, C, S, dz_data, x_data, gamma_data, beta_data, 134 | mean_data, var_data, sum_dz_data, sum_dz_xhat_data, 135 | dx_data, dgamma_data, dbeta_data, eps, stream); 136 | } -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/src/syncbn.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "common.h" 6 | #include "syncbn.cu.h" 7 | 8 | /* 9 | * Device functions and data structures 10 | */ 11 | struct Float2 { 12 | float v1, v2; 13 | __device__ Float2() {} 14 | __device__ Float2(float _v1, float _v2) : v1(_v1), v2(_v2) {} 15 | __device__ Float2(float v) : v1(v), v2(v) {} 16 | __device__ Float2(int v) : v1(v), v2(v) {} // used by the (T)0 initialisation in reduce() 17 | __device__ Float2 &operator+=(const Float2 &a) { 18 | v1 += a.v1; 19 | v2 += a.v2; 20 | return *this; 21 | } 22 | }; 23 | // NOTE(review): GradOp is not referenced by any kernel in this file 24 | struct GradOp { 25 | __device__ GradOp(float _gamma, float _beta, const float *_z, 26 | const float *_dz, int c, int s) 27 | : gamma(_gamma), beta(_beta), z(_z), dz(_dz), C(c), S(s) {} 28 | __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { 29 | float
_y = (z[(batch * C + plane) * S + n] - beta) / gamma; 30 | float _dz = dz[(batch * C + plane) * S + n]; 31 | return Float2(_dz, _y * _dz); 32 | } 33 | const float gamma; 34 | const float beta; 35 | const float *z; 36 | const float *dz; 37 | const int C; 38 | const int S; 39 | }; 40 | // Warp-wide sum; on __CUDA_ARCH__ >= 300 an XOR-shuffle butterfly leaves the total in every lane 41 | static __device__ __forceinline__ float warpSum(float val) { 42 | #if __CUDA_ARCH__ >= 300 43 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 44 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 45 | } 46 | #else 47 | __shared__ float values[MAX_BLOCK_SIZE]; 48 | values[threadIdx.x] = val; 49 | __threadfence_block(); 50 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 51 | for (int i = 1; i < WARP_SIZE; i++) { 52 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 53 | } 54 | #endif 55 | return val; 56 | } 57 | 58 | static __device__ __forceinline__ Float2 warpSum(Float2 value) { 59 | value.v1 = warpSum(value.v1); 60 | value.v2 = warpSum(value.v2); 61 | return value; 62 | } 63 | // Block-wide sum of op(batch, plane, n) over all N batches and S positions of one plane; every thread receives the result 64 | template 65 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 66 | T sum = (T)0; 67 | for (int batch = 0; batch < N; ++batch) { 68 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 69 | sum += op(batch, plane, x); 70 | } 71 | } 72 | 73 | // sum over NumThreads within a warp 74 | sum = warpSum(sum); 75 | 76 | // 'transpose', and reduce within warp again 77 | __shared__ T shared[32]; 78 | __syncthreads(); 79 | if (threadIdx.x % WARP_SIZE == 0) { 80 | shared[threadIdx.x / WARP_SIZE] = sum; 81 | } 82 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 83 | // zero out the other entries in shared 84 | shared[threadIdx.x] = (T)0; 85 | } 86 | __syncthreads(); 87 | if (threadIdx.x / WARP_SIZE == 0) { 88 | sum = warpSum(shared[threadIdx.x]); 89 | if (threadIdx.x == 0) { 90 | shared[0] = sum; 91 | } 92 | } 93 | __syncthreads(); 94 | 95 | // Everyone picks it up, should be broadcast into the whole gradInput 96 | return shared[0]; 97 | } 98 | 99 |
/*---------------------------------------------------------------------------- 100 | * 101 | * BatchNorm2dSyncFunc Kernel implementations 102 | * 103 | *---------------------------------------------------------------------------*/ 104 | // Per element: (x, x^2), so a single reduce() yields the sum and the sum of squares 105 | struct SqSumOp { 106 | __device__ SqSumOp(const float *t, int c, int s) 107 | : tensor(t), C(c), S(s) {} 108 | __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { 109 | float t = tensor[(batch * C + plane) * S + n]; 110 | return Float2(t, t * t); 111 | } 112 | const float *tensor; 113 | const int C; 114 | const int S; 115 | }; 116 | // Per element: (dz, dz * xhat); syncbn_backward_xhat_kernel passes invstd as _gamma and the mean as _beta 117 | struct XHatOp { 118 | __device__ XHatOp(float _gamma, float _beta, const float *_z, 119 | const float *_dz, int c, int s) 120 | : gamma(_gamma), beta(_beta), z(_z), dz(_dz), C(c), S(s) {} 121 | __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) { 122 | // xhat = (x - mean) * invstd (beta holds the mean, gamma holds invstd; z is the input x) 123 | float _xhat = (z[(batch * C + plane) * S + n] - beta) * gamma; 124 | // second component accumulates dz * xhat 125 | float _dz = dz[(batch * C + plane) * S + n]; 126 | return Float2(_dz, _dz * _xhat); 127 | } 128 | const float gamma; 129 | const float beta; 130 | const float *z; 131 | const float *dz; 132 | const int C; 133 | const int S; 134 | }; 135 | // One block per channel (plane = blockIdx.x): writes that channel's sum and sum of squares 136 | __global__ void syncbn_sum_sqsum_kernel(const float *x, float *sum, float *sqsum, 137 | int N, int C, int S) { 138 | int plane = blockIdx.x; 139 | Float2 res = reduce(SqSumOp(x, C, S), plane, N, C, S); 140 | float _sum = res.v1; 141 | float _sqsum = res.v2; 142 | __syncthreads(); 143 | if (threadIdx.x == 0) { 144 | sum[plane] = _sum; 145 | sqsum[plane] = _sqsum; 146 | } 147 | } 148 | 149 | __global__ void syncbn_forward_kernel( 150 | float *z, const float *x, const float *gamma, const float *beta, 151 | const float *mean, const float *var, float eps, int N, int C, int S) { 152 | 153 | int c = blockIdx.x; 154 | float _mean = mean[c]; 155 | float _var = var[c]; 156 | float invtsd = 0; 157 | if (_var != 0.f || eps != 0.f) { 158 | invtsd =
1 / sqrt(_var + eps); 159 | } 160 | float _gamma = gamma != 0 ? gamma[c] : 1.f; // absent affine parameters behave as gamma = 1, beta = 0 161 | float _beta = beta != 0 ? beta[c] : 0.f; 162 | for (int batch = 0; batch < N; ++batch) { 163 | for (int n = threadIdx.x; n < S; n += blockDim.x) { 164 | float _x = x[(batch * C + c) * S + n]; 165 | float _xhat = (_x - _mean) * invtsd; 166 | float _z = _xhat * _gamma + _beta; 167 | z[(batch * C + c) * S + n] = _z; 168 | } 169 | } 170 | } 171 | // Per channel: reduce sum(dz) and sum(dz * xhat) over this replica's elements 172 | __global__ void syncbn_backward_xhat_kernel( 173 | const float *dz, const float *x, const float *mean, const float *var, 174 | float *sum_dz, float *sum_dz_xhat, float eps, int N, int C, int S) { 175 | 176 | int c = blockIdx.x; 177 | float _mean = mean[c]; 178 | float _var = var[c]; 179 | float _invstd = 0; 180 | if (_var != 0.f || eps != 0.f) { 181 | _invstd = 1 / sqrt(_var + eps); 182 | } 183 | Float2 res = reduce( 184 | XHatOp(_invstd, _mean, x, dz, C, S), c, N, C, S); 185 | // \sum(\frac{dJ}{dy_i}) 186 | float _sum_dz = res.v1; 187 | // \sum(\frac{dJ}{dy_i}*\hat{x_i}) 188 | float _sum_dz_xhat = res.v2; 189 | __syncthreads(); 190 | if (threadIdx.x == 0) { 191 | // \sum(\frac{dJ}{dy_i}) 192 | sum_dz[c] = _sum_dz; 193 | // \sum(\frac{dJ}{dy_i}*\hat{x_i}) 194 | sum_dz_xhat[c] = _sum_dz_xhat; 195 | } 196 | } 197 | 198 | 199 | __global__ void syncbn_backward_kernel( 200 | const float *dz, const float *x, const float *gamma, const float *beta, 201 | const float *mean, const float *var, 202 | const float *sum_dz, const float *sum_dz_xhat, 203 | float *dx, float *dgamma, float *dbeta, 204 | float eps, int N, int C, int S) { 205 | 206 | int c = blockIdx.x; 207 | float _mean = mean[c]; 208 | float _var = var[c]; 209 | float _gamma = gamma != 0 ?
gamma[c] : 1.f; 210 | float _sum_dz = sum_dz[c]; 211 | float _sum_dz_xhat = sum_dz_xhat[c]; 212 | float _invstd = 0; 213 | if (_var != 0.f || eps != 0.f) { 214 | _invstd = 1 / sqrt(_var + eps); 215 | } 216 | /* 217 | \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} ( 218 | N\frac{dJ}{d\hat{x_i}} - 219 | \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) - 220 | \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j}) 221 | ) 222 | Note : N is omitted here since it will be accumulated and 223 | _sum_dz and _sum_dz_xhat expected to be already normalized 224 | before the call. 225 | */ 226 | if (dx != 0) { 227 | float _mul = _gamma * _invstd; 228 | for (int batch = 0; batch < N; ++batch) { 229 | for (int n = threadIdx.x; n < S; n += blockDim.x) { 230 | float _dz = dz[(batch * C + c) * S + n]; 231 | float _xhat = (x[(batch * C + c) * S + n] - _mean) * _invstd; 232 | float _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul; 233 | dx[(batch * C + c) * S + n] = _dx; 234 | } 235 | } 236 | } 237 | float _norm = N * S; // this replica's element count; sum_dz/sum_dz_xhat arrive divided by the global count (syncbn.py), so this rescales them to a per-replica share, presumably summed across replicas afterwards (TODO confirm) 238 | if (dgamma != 0) { 239 | if (threadIdx.x == 0) { 240 | // \frac{dJ}{d\gamma} = \sum(\frac{dJ}{dy_i}*\hat{x_i}) 241 | dgamma[c] += _sum_dz_xhat * _norm; 242 | } 243 | } 244 | if (dbeta != 0) { 245 | if (threadIdx.x == 0) { 246 | // \frac{dJ}{d\beta} = \sum(\frac{dJ}{dy_i}) 247 | dbeta[c] += _sum_dz * _norm; 248 | } 249 | } 250 | } 251 | // Host launchers: one block per channel, getNumThreads(S) threads; return 1 on success, 0 if cudaGetLastError reports an error 252 | extern "C" int _syncbn_sum_sqsum_cuda(int N, int C, int S, 253 | const float *x, float *sum, float *sqsum, 254 | cudaStream_t stream) { 255 | // Run kernel 256 | dim3 blocks(C); 257 | dim3 threads(getNumThreads(S)); 258 | syncbn_sum_sqsum_kernel<<>>(x, sum, sqsum, N, C, S); 259 | 260 | // Check for errors 261 | cudaError_t err = cudaGetLastError(); 262 | if (err != cudaSuccess) 263 | return 0; 264 | else 265 | return 1; 266 | } 267 | 268 | extern "C" int _syncbn_forward_cuda( 269 | int N, int C, int S, float *z, const float *x, 270 | const float *gamma, const float *beta, const float *mean, const float *var, 271 | float
eps, cudaStream_t stream) { 272 | 273 | // Run kernel 274 | dim3 blocks(C); 275 | dim3 threads(getNumThreads(S)); 276 | syncbn_forward_kernel<<>>( 277 | z, x, gamma, beta, mean, var, eps, N, C, S); 278 | 279 | // Check for errors 280 | cudaError_t err = cudaGetLastError(); 281 | if (err != cudaSuccess) 282 | return 0; 283 | else 284 | return 1; 285 | } 286 | 287 | 288 | extern "C" int _syncbn_backward_xhat_cuda( 289 | int N, int C, int S, const float *dz, const float *x, 290 | const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat, 291 | float eps, cudaStream_t stream) { 292 | 293 | // Run kernel 294 | dim3 blocks(C); 295 | dim3 threads(getNumThreads(S)); 296 | syncbn_backward_xhat_kernel<<>>( 297 | dz, x,mean, var, sum_dz, sum_dz_xhat, eps, N, C, S); 298 | 299 | // Check for errors 300 | cudaError_t err = cudaGetLastError(); 301 | if (err != cudaSuccess) 302 | return 0; 303 | else 304 | return 1; 305 | } 306 | 307 | 308 | extern "C" int _syncbn_backward_cuda( 309 | int N, int C, int S, const float *dz, const float *x, 310 | const float *gamma, const float *beta, const float *mean, const float *var, 311 | const float *sum_dz, const float *sum_dz_xhat, 312 | float *dx, float *dgamma, float *dbeta, float eps, cudaStream_t stream) { 313 | 314 | // Run kernel 315 | dim3 blocks(C); 316 | dim3 threads(getNumThreads(S)); 317 | syncbn_backward_kernel<<>>( 318 | dz, x, gamma, beta, mean, var, sum_dz, sum_dz_xhat, 319 | dx, dgamma, dbeta, eps, N, C, S); 320 | 321 | // Check for errors 322 | cudaError_t err = cudaGetLastError(); 323 | if (err != cudaSuccess) 324 | return 0; 325 | else 326 | return 1; 327 | } 328 | 329 | -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/src/syncbn.cu.h: -------------------------------------------------------------------------------- 1 | #ifndef __SYNCBN__ 2 | #define __SYNCBN__ 3 | 4 | /* 5 | * Exported functions 6 | */ 7 | extern "C" int
_syncbn_sum_sqsum_cuda(int N, int C, int S, const float *x, 8 | float *sum, float *sqsum, 9 | cudaStream_t stream); 10 | extern "C" int _syncbn_forward_cuda( 11 | int N, int C, int S, float *z, const float *x, 12 | const float *gamma, const float *beta, const float *mean, const float *var, 13 | float eps, cudaStream_t stream); 14 | extern "C" int _syncbn_backward_xhat_cuda( 15 | int N, int C, int S, const float *dz, const float *x, 16 | const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat, 17 | float eps, cudaStream_t stream); 18 | extern "C" int _syncbn_backward_cuda( 19 | int N, int C, int S, const float *dz, const float *x, 20 | const float *gamma, const float *beta, const float *mean, const float *var, 21 | const float *sum_dz, const float *sum_dz_xhat, 22 | float *dx, float *dweight, float *dbias, 23 | float eps, cudaStream_t stream); 24 | 25 | 26 | #endif 27 | -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/src/syncbn.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/_syncbn/src/syncbn.cu.o -------------------------------------------------------------------------------- /models/syncbn/modules/functional/_syncbn/src/syncbn.h: -------------------------------------------------------------------------------- 1 | int syncbn_sum_sqsum_cuda( 2 | const THCudaTensor *x, THCudaTensor *sum, THCudaTensor *sqsum); 3 | int syncbn_forward_cuda( 4 | THCudaTensor *z, const THCudaTensor *x, 5 | const THCudaTensor *gamma, const THCudaTensor *beta, 6 | const THCudaTensor *mean, const THCudaTensor *var, float eps); 7 | int syncbn_backward_xhat_cuda( 8 | const THCudaTensor *dz, const THCudaTensor *x, 9 | const THCudaTensor *mean, const THCudaTensor *var, 10 | THCudaTensor *sum_dz, THCudaTensor *sum_dz_xhat,
11 | float eps); 12 | int syncbn_backard_cuda( 13 | const THCudaTensor *dz, const THCudaTensor *x, 14 | const THCudaTensor *gamma, const THCudaTensor *beta, 15 | const THCudaTensor *mean, const THCudaTensor *var, 16 | const THCudaTensor *sum_dz, const THCudaTensor *sum_dz_xhat, 17 | THCudaTensor *dx, THCudaTensor *dgamma, THCudaTensor *dbeta, float eps); 18 | -------------------------------------------------------------------------------- /models/syncbn/modules/functional/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | code referenced from : https://github.com/mapillary/inplace_abn 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch.cuda.comm as comm 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | from ._syncbn._ext import syncbn as _lib_bn 19 | 20 | # Number of elements per channel: product of every dim except dim 1. 21 | def _count_samples(x): 22 | count = 1 23 | for i, s in enumerate(x.size()): 24 | if i != 1: 25 | count *= s 26 | return count 27 | 28 | # Raise ValueError unless every non-None argument is contiguous (the CUDA kernels index raw memory). 29 | def _check_contiguous(*args): 30 | if not all([mod is None or mod.is_contiguous() for mod in args]): 31 | raise ValueError("Non-contiguous input") 32 | 33 | # Autograd function for synchronized BatchNorm; `extra` carries the queues/ids used to exchange statistics between replicas. 34 | class BatchNorm2dSyncFunc(Function): 35 | 36 | @classmethod 37 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 38 | extra, compute_stats=True, momentum=0.1, eps=1e-05): 39 | # Save context. NOTE(review): if extra is None, is_master/master_queue are never set, so compute_stats=True fails below. 40 | if extra is not None: 41 | cls._parse_extra(ctx, extra) 42 | ctx.compute_stats = compute_stats 43 | ctx.momentum = momentum 44 | ctx.eps = eps 45 | if ctx.compute_stats: 46 | N = _count_samples(x) * (ctx.master_queue.maxsize + 1)  # elements per channel over all replicas (assumes master_queue.maxsize == number of workers) 47 | assert N > 1 48 | num_features =
running_mean.size(0) 49 | # 1. compute sum(x) and sum(x^2) 50 | xsum = x.new().resize_(num_features) 51 | xsqsum = x.new().resize_(num_features) 52 | _check_contiguous(x, xsum, xsqsum) 53 | _lib_bn.syncbn_sum_sqsum_cuda(x.detach(), xsum, xsqsum) 54 | if ctx.is_master: 55 | xsums, xsqsums = [xsum], [xsqsum] 56 | # master : gather all sum(x) and sum(x^2) from slaves 57 | for _ in range(ctx.master_queue.maxsize): 58 | xsum_w, xsqsum_w = ctx.master_queue.get() 59 | ctx.master_queue.task_done() 60 | xsums.append(xsum_w) 61 | xsqsums.append(xsqsum_w) 62 | xsum = comm.reduce_add(xsums) 63 | xsqsum = comm.reduce_add(xsqsums) 64 | mean = xsum / N 65 | sumvar = xsqsum - xsum * mean  # sum(x^2) - N * mean^2 66 | var = sumvar / N 67 | uvar = sumvar / (N - 1) 68 | # master : broadcast global mean, variance to all slaves 69 | tensors = comm.broadcast_coalesced( 70 | (mean, uvar, var), [mean.get_device()] + ctx.worker_ids) 71 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 72 | queue.put(ts) 73 | else: 74 | # slave : send sum(x) and sum(x^2) to master 75 | ctx.master_queue.put((xsum, xsqsum)) 76 | # slave : get global mean and variance 77 | mean, uvar, var = ctx.worker_queue.get() 78 | ctx.worker_queue.task_done() 79 | 80 | # Update running stats 81 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 82 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)  # running_var tracks the unbiased (N - 1) variance 83 | ctx.N = N 84 | ctx.save_for_backward(x, weight, bias, mean, var) 85 | else: 86 | mean, var = running_mean, running_var  # inference path: nothing is saved for backward 87 | 88 | output = x.new().resize_as_(x) 89 | _check_contiguous(output, x, mean, var, weight, bias) 90 | # do batch norm forward (NOTE(review): the 0/1 status returned by the _lib_bn calls is never checked) 91 | _lib_bn.syncbn_forward_cuda( 92 | output, x, weight if weight is not None else x.new(), 93 | bias if bias is not None else x.new(), mean, var, ctx.eps) 94 | return output 95 | 96 | @staticmethod 97 | @once_differentiable 98 | def backward(ctx, dz): 99 | x, weight, bias, mean, var = ctx.saved_tensors 100 | dz = dz.contiguous() 101 | if ctx.needs_input_grad[0]:
102 | dx = dz.new().resize_as_(dz) 103 | else: 104 | dx = None 105 | if ctx.needs_input_grad[1]: 106 | dweight = dz.new().resize_as_(mean).zero_()  # zeroed because the kernel accumulates into dgamma/dbeta with += 107 | else: 108 | dweight = None 109 | if ctx.needs_input_grad[2]: 110 | dbias = dz.new().resize_as_(mean).zero_() 111 | else: 112 | dbias = None 113 | _check_contiguous(x, dz, weight, bias, mean, var) 114 | 115 | # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i}) 116 | num_features = mean.size(0) 117 | sum_dz = x.new().resize_(num_features) 118 | sum_dz_xhat = x.new().resize_(num_features) 119 | _check_contiguous(sum_dz, sum_dz_xhat) 120 | _lib_bn.syncbn_backward_xhat_cuda( 121 | dz, x, mean, var, sum_dz, sum_dz_xhat, ctx.eps) 122 | if ctx.is_master: 123 | sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat] 124 | # master : gather from slaves 125 | for _ in range(ctx.master_queue.maxsize): 126 | sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get() 127 | ctx.master_queue.task_done() 128 | sum_dzs.append(sum_dz_w) 129 | sum_dz_xhats.append(sum_dz_xhat_w) 130 | # master : compute global stats 131 | sum_dz = comm.reduce_add(sum_dzs) 132 | sum_dz_xhat = comm.reduce_add(sum_dz_xhats) 133 | sum_dz /= ctx.N  # normalize by the element count over all replicas 134 | sum_dz_xhat /= ctx.N 135 | # master : broadcast global stats 136 | tensors = comm.broadcast_coalesced( 137 | (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids) 138 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 139 | queue.put(ts) 140 | else: 141 | # slave : send to master 142 | ctx.master_queue.put((sum_dz, sum_dz_xhat)) 143 | # slave : get global stats 144 | sum_dz, sum_dz_xhat = ctx.worker_queue.get() 145 | ctx.worker_queue.task_done() 146 | 147 | # do batch norm backward (empty tensors stand in for missing weight/bias/gradient outputs) 148 | _lib_bn.syncbn_backard_cuda( 149 | dz, x, weight if weight is not None else dz.new(), 150 | bias if bias is not None else dz.new(), 151 | mean, var, sum_dz, sum_dz_xhat, 152 | dx if dx is not None else dz.new(), 153 | dweight if dweight is not None else dz.new(), 154 | dbias if dbias is not None else dz.new(),
ctx.eps) 155 | # NOTE(review): forward has 9 inputs but 11 values are returned; the surplus trailing Nones are presumably tolerated by the targeted PyTorch version - verify. 156 | return dx, dweight, dbias, None, None, None, \ 157 | None, None, None, None, None 158 | # Copy the master/worker queues and ids from `extra` onto ctx. 159 | @staticmethod 160 | def _parse_extra(ctx, extra): 161 | ctx.is_master = extra["is_master"] 162 | if ctx.is_master: 163 | ctx.master_queue = extra["master_queue"] 164 | ctx.worker_queues = extra["worker_queues"] 165 | ctx.worker_ids = extra["worker_ids"] 166 | else: 167 | ctx.master_queue = extra["master_queue"] 168 | ctx.worker_queue = extra["worker_queue"] 169 | 170 | batchnorm2d_sync = BatchNorm2dSyncFunc.apply 171 | 172 | __all__ = ["batchnorm2d_sync"] 173 | -------------------------------------------------------------------------------- /models/syncbn/modules/functional/syncbn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/functional/syncbn.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import * 2 | -------------------------------------------------------------------------------- /models/syncbn/modules/nn/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/__init__.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/nn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/__pycache__/__init__.cpython-37.pyc
-------------------------------------------------------------------------------- /models/syncbn/modules/nn/__pycache__/syncbn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/__pycache__/syncbn.cpython-37.pyc -------------------------------------------------------------------------------- /models/syncbn/modules/nn/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | /*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | try: 13 | # python 3 14 | from queue import Queue 15 | except ImportError: 16 | # python 2 17 | from Queue import Queue 18 | 19 | import torch 20 | import torch.nn as nn 21 | from modules.functional import batchnorm2d_sync 22 | 23 | 24 | class BatchNorm2d(nn.BatchNorm2d): 25 | """ 26 | BatchNorm2d with automatic multi-GPU Sync 27 | """ 28 | 29 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 30 | track_running_stats=True): 31 | super(BatchNorm2d, self).__init__( 32 | num_features, eps=eps, momentum=momentum, affine=affine, 33 | track_running_stats=track_running_stats) 34 | self.devices = list(range(torch.cuda.device_count())) 35 | if len(self.devices) > 1: 36 | # Initialize queues 37 | self.worker_ids = self.devices[1:] 38 | self.master_queue = Queue(len(self.worker_ids)) 39 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 40 | 41 | def forward(self, x): 42 | compute_stats = self.training or not self.track_running_stats 43 | if compute_stats and len(self.devices) > 1: 44 | if x.get_device() == 
self.devices[0]: 45 | # Master mode 46 | extra = { 47 | "is_master": True, 48 | "master_queue": self.master_queue, 49 | "worker_queues": self.worker_queues, 50 | "worker_ids": self.worker_ids 51 | } 52 | else: 53 | # Worker mode 54 | extra = { 55 | "is_master": False, 56 | "master_queue": self.master_queue, 57 | "worker_queue": self.worker_queues[ 58 | self.worker_ids.index(x.get_device())] 59 | } 60 | return batchnorm2d_sync(x, self.weight, self.bias, 61 | self.running_mean, self.running_var, 62 | extra, compute_stats, self.momentum, 63 | self.eps) 64 | return super(BatchNorm2d, self).forward(x) 65 | 66 | def __repr__(self): 67 | """repr""" 68 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 69 | ' affine={affine}, devices={devices})' 70 | return rep.format(name=self.__class__.__name__, **self.__dict__) 71 | -------------------------------------------------------------------------------- /models/syncbn/modules/nn/syncbn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KexianHust/Structure-Guided-Ranking-Loss/6fa853da2aeb53ad0a66e95484a5ae7fc816bb4a/models/syncbn/modules/nn/syncbn.pyc -------------------------------------------------------------------------------- /models/syncbn/requirements.txt: -------------------------------------------------------------------------------- 1 | future 2 | cffi 3 | -------------------------------------------------------------------------------- /models/syncbn/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | Test for BatchNorm2dSync with multi-gpu 5 | 6 | /*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import sys 13 | import numpy as np 
14 | import torch 15 | from torch import nn 16 | from torch.nn import functional as F 17 | sys.path.append("./") 18 | from modules import nn as NN 19 | 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | def init_weight(model): 24 | for m in model.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 27 | m.weight.data.normal_(0, np.sqrt(2. / n)) 28 | elif isinstance(m, NN.BatchNorm2d) or isinstance(m, nn.BatchNorm2d): 29 | m.weight.data.fill_(1) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.Linear): 32 | m.bias.data.zero_() 33 | 34 | num_gpu = torch.cuda.device_count() 35 | print("num_gpu={}".format(num_gpu)) 36 | if num_gpu < 2: 37 | print("No multi-gpu found. NN.BatchNorm2d will act as normal nn.BatchNorm2d") 38 | 39 | m1 = nn.Sequential( 40 | nn.Conv2d(3, 3, 1, 1, bias=False), 41 | nn.BatchNorm2d(3), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(3, 3, 1, 1, bias=False), 44 | nn.BatchNorm2d(3), 45 | ).cuda() 46 | torch.manual_seed(123) 47 | init_weight(m1) 48 | m2 = nn.Sequential( 49 | nn.Conv2d(3, 3, 1, 1, bias=False), 50 | NN.BatchNorm2d(3), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(3, 3, 1, 1, bias=False), 53 | NN.BatchNorm2d(3), 54 | ).cuda() 55 | torch.manual_seed(123) 56 | init_weight(m2) 57 | m2 = nn.DataParallel(m2, device_ids=range(num_gpu)) 58 | o1 = torch.optim.SGD(m1.parameters(), 1e-3) 59 | o2 = torch.optim.SGD(m2.parameters(), 1e-3) 60 | y = torch.ones(num_gpu).float().cuda() 61 | torch.manual_seed(123) 62 | for _ in range(100): 63 | x = torch.rand(num_gpu, 3, 2, 2).cuda() 64 | o1.zero_grad() 65 | z1 = m1(x) 66 | l1 = F.mse_loss(z1.mean(-1).mean(-1).mean(-1), y) 67 | l1.backward() 68 | o1.step() 69 | o2.zero_grad() 70 | z2 = m2(x) 71 | l2 = F.mse_loss(z2.mean(-1).mean(-1).mean(-1), y) 72 | l2.backward() 73 | o2.step() 74 | print(m2.module[1].bias.grad - m1[1].bias.grad) 75 | print(m2.module[1].weight.grad - m1[1].weight.grad) 76 | print(m2.module[-1].bias.grad - m1[-1].bias.grad) 77 
| print(m2.module[-1].weight.grad - m1[-1].weight.grad) 78 | m2 = m2.module 79 | print("===============================") 80 | print("m1(nn.BatchNorm2d) running_mean", 81 | m1[1].running_mean, m1[-1].running_mean) 82 | print("m2(NN.BatchNorm2d) running_mean", 83 | m2[1].running_mean, m2[-1].running_mean) 84 | print("m1(nn.BatchNorm2d) running_var", m1[1].running_var, m1[-1].running_var) 85 | print("m2(NN.BatchNorm2d) running_var", m2[1].running_var, m2[-1].running_var) 86 | print("m1(nn.BatchNorm2d) weight", m1[1].weight, m1[-1].weight) 87 | print("m2(NN.BatchNorm2d) weight", m2[1].weight, m2[-1].weight) 88 | print("m1(nn.BatchNorm2d) bias", m1[1].bias, m1[-1].bias) 89 | print("m2(NN.BatchNorm2d) bias", m2[1].bias, m2[-1].bias) 90 | -------------------------------------------------------------------------------- /ranking_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | """ 8 | Sampling strategies: RS (Random Sampling), EGS (Edge-Guided Sampling), and IGS (Instance-Guided Sampling) 9 | """ 10 | ########### 11 | # RANDOM SAMPLING 12 | # input: 13 | # inputs[i,:], targets[i, :], masks[i, :], self.mask_value, self.point_pairs 14 | # return: 15 | # inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B 16 | ########### 17 | def randomSampling(inputs, targets, masks, threshold, sample_num): 18 | 19 | # find A-B point pairs from predictions 20 | inputs_index = torch.masked_select(inputs, targets.gt(threshold)) 21 | num_effect_pixels = len(inputs_index) 22 | shuffle_effect_pixels = torch.randperm(num_effect_pixels).cuda() 23 | inputs_A = inputs_index[shuffle_effect_pixels[0:sample_num*2:2]] 24 | inputs_B = inputs_index[shuffle_effect_pixels[1:sample_num*2:2]] 25 | # find corresponding pairs from GT 26 | target_index = torch.masked_select(targets, targets.gt(threshold)) 27 | targets_A = 
target_index[shuffle_effect_pixels[0:sample_num*2:2]] 28 | targets_B = target_index[shuffle_effect_pixels[1:sample_num*2:2]] 29 | # only compute the losses of point pairs with valid GT 30 | consistent_masks_index = torch.masked_select(masks, targets.gt(threshold)) 31 | consistent_masks_A = consistent_masks_index[shuffle_effect_pixels[0:sample_num*2:2]] 32 | consistent_masks_B = consistent_masks_index[shuffle_effect_pixels[1:sample_num*2:2]] 33 | 34 | # The amount of A and B should be the same!! 35 | if len(targets_A) > len(targets_B): 36 | targets_A = targets_A[:-1] 37 | inputs_A = inputs_A[:-1] 38 | consistent_masks_A = consistent_masks_A[:-1] 39 | 40 | return inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B 41 | 42 | ########### 43 | # EDGE-GUIDED SAMPLING 44 | # input: 45 | # inputs[i,:], targets[i, :], masks[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w 46 | # return: 47 | # inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B 48 | ########### 49 | def ind2sub(idx, cols): 50 | r = idx / cols 51 | c = idx - r * cols 52 | return r, c 53 | 54 | def sub2ind(r, c, cols): 55 | idx = r * cols + c 56 | return idx 57 | 58 | def edgeGuidedSampling(inputs, targets, edges_img, thetas_img, masks, h, w): 59 | 60 | # find edges 61 | edges_max = edges_img.max() 62 | edges_mask = edges_img.ge(edges_max*0.1) 63 | edges_loc = edges_mask.nonzero() 64 | 65 | inputs_edge = torch.masked_select(inputs, edges_mask) 66 | targets_edge = torch.masked_select(targets, edges_mask) 67 | thetas_edge = torch.masked_select(thetas_img, edges_mask) 68 | minlen = inputs_edge.size()[0] 69 | 70 | # find anchor points (i.e, edge points) 71 | sample_num = minlen 72 | index_anchors = torch.randint(0, minlen, (sample_num,), dtype=torch.long).cuda() 73 | anchors = torch.gather(inputs_edge, 0, index_anchors) 74 | theta_anchors = torch.gather(thetas_edge, 0, index_anchors) 75 | row_anchors, col_anchors = ind2sub(edges_loc[index_anchors].squeeze(1), w) 76 | ## 
compute the coordinates of 4-points, distances are from [2, 30] 77 | distance_matrix = torch.randint(2, 31, (4,sample_num)).cuda() 78 | pos_or_neg = torch.ones(4, sample_num).cuda() 79 | pos_or_neg[:2,:] = -pos_or_neg[:2,:] 80 | distance_matrix = distance_matrix.float() * pos_or_neg 81 | col = col_anchors.unsqueeze(0).expand(4, sample_num).long() + torch.round(distance_matrix.double() * torch.cos(theta_anchors).unsqueeze(0)).long() 82 | row = row_anchors.unsqueeze(0).expand(4, sample_num).long() + torch.round(distance_matrix.double() * torch.sin(theta_anchors).unsqueeze(0)).long() 83 | 84 | # constrain 0=w-1] = w-1 88 | row[row<0] = 0 89 | row[row>h-1] = h-1 90 | 91 | # a-b, b-c, c-d 92 | a = sub2ind(row[0,:], col[0,:], w) 93 | b = sub2ind(row[1,:], col[1,:], w) 94 | c = sub2ind(row[2,:], col[2,:], w) 95 | d = sub2ind(row[3,:], col[3,:], w) 96 | A = torch.cat((a,b,c), 0) 97 | B = torch.cat((b,c,d), 0) 98 | 99 | inputs_A = torch.gather(inputs, 0, A.long()) 100 | inputs_B = torch.gather(inputs, 0, B.long()) 101 | targets_A = torch.gather(targets, 0, A.long()) 102 | targets_B = torch.gather(targets, 0, B.long()) 103 | masks_A = torch.gather(masks, 0, A.long()) 104 | masks_B = torch.gather(masks, 0, B.long()) 105 | 106 | return inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num 107 | 108 | ###################################################### 109 | # EdgeguidedRankingLoss (with regularization term) 110 | # Please comment regularization_loss if you don't want to use multi-scale gradient matching term 111 | ##################################################### 112 | class EdgeguidedRankingLoss(nn.Module): 113 | def __init__(self, point_pairs=10000, sigma=0.03, alpha=1.0, mask_value=-1e-8): 114 | super(EdgeguidedRankingLoss, self).__init__() 115 | self.point_pairs = point_pairs # number of point pairs 116 | self.sigma = sigma # used for determining the ordinal relationship between a selected pair 117 | self.alpha = alpha # used for balancing the 
effect of = and (<,>) 118 | self.mask_value = mask_value 119 | #self.regularization_loss = GradientLoss(scales=4) 120 | 121 | def getEdge(self, images): 122 | n,c,h,w = images.size() 123 | a = torch.Tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).cuda().view((1,1,3,3)).repeat(1, 1, 1, 1) 124 | b = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).cuda().view((1,1,3,3)).repeat(1, 1, 1, 1) 125 | if c == 3: 126 | gradient_x = F.conv2d(images[:,0,:,:].unsqueeze(1), a) 127 | gradient_y = F.conv2d(images[:,0,:,:].unsqueeze(1), b) 128 | else: 129 | gradient_x = F.conv2d(images, a) 130 | gradient_y = F.conv2d(images, b) 131 | edges = torch.sqrt(torch.pow(gradient_x,2)+ torch.pow(gradient_y,2)) 132 | edges = F.pad(edges, (1,1,1,1), "constant", 0) 133 | thetas = torch.atan2(gradient_y, gradient_x) 134 | thetas = F.pad(thetas, (1,1,1,1), "constant", 0) 135 | 136 | return edges, thetas 137 | 138 | def forward(self, inputs, targets, images, masks=None): 139 | if masks == None: 140 | masks = targets > self.mask_value 141 | # Comment this line if you don't want to use the multi-scale gradient matching term !!! 
142 | # regularization_loss = self.regularization_loss(inputs.squeeze(1), targets.squeeze(1), masks.squeeze(1)) 143 | # find edges from RGB 144 | edges_img, thetas_img = self.getEdge(images) 145 | 146 | #============================= 147 | n,c,h,w = targets.size() 148 | if n != 1: 149 | inputs = inputs.view(n, -1).double() 150 | targets = targets.view(n, -1).double() 151 | masks = masks.view(n, -1).double() 152 | edges_img = edges_img.view(n, -1).double() 153 | thetas_img = thetas_img.view(n, -1).double() 154 | 155 | else: 156 | inputs = inputs.contiguous().view(1, -1).double() 157 | targets = targets.contiguous().view(1, -1).double() 158 | masks = masks.contiguous().view(1, -1).double() 159 | edges_img = edges_img.contiguous().view(1, -1).double() 160 | thetas_img = thetas_img.contiguous().view(1, -1).double() 161 | 162 | # initialization 163 | loss = torch.DoubleTensor([0.0]).cuda() 164 | 165 | 166 | for i in range(n): 167 | # Edge-Guided sampling 168 | inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num = edgeGuidedSampling(inputs[i,:], targets[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w) 169 | # Random Sampling 170 | random_sample_num = sample_num 171 | random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(inputs[i,:], targets[i, :], masks[i, :], self.mask_value, random_sample_num) 172 | 173 | # Combine EGS + RS 174 | inputs_A = torch.cat((inputs_A, random_inputs_A), 0) 175 | inputs_B = torch.cat((inputs_B, random_inputs_B), 0) 176 | targets_A = torch.cat((targets_A, random_targets_A), 0) 177 | targets_B = torch.cat((targets_B, random_targets_B), 0) 178 | masks_A = torch.cat((masks_A, random_masks_A), 0) 179 | masks_B = torch.cat((masks_B, random_masks_B), 0) 180 | 181 | #GT ordinal relationship 182 | target_ratio = torch.div(targets_A+1e-6, targets_B+1e-6) 183 | mask_eq = target_ratio.lt(1.0 + self.sigma) * target_ratio.gt(1.0/(1.0+self.sigma)) 184 | labels = 
torch.zeros_like(target_ratio) 185 | labels[target_ratio.ge(1.0 + self.sigma)] = 1 186 | labels[target_ratio.le(1.0/(1.0+self.sigma))] = -1 187 | 188 | # consider forward-backward consistency checking, i.e, only compute losses of point pairs with valid GT 189 | consistency_mask = masks_A * masks_B 190 | 191 | equal_loss = (inputs_A - inputs_B).pow(2) * mask_eq.double() * consistency_mask 192 | unequal_loss = torch.log(1 + torch.exp((-inputs_A + inputs_B) * labels)) * (~mask_eq).double() * consistency_mask 193 | 194 | # Please comment the regularization term if you don't want to use the multi-scale gradient matching loss !!! 195 | loss = loss + self.alpha * equal_loss.mean() + 1.0 * unequal_loss.mean() #+ 0.2 * regularization_loss.double() 196 | 197 | return loss[0].float()/n 198 | --------------------------------------------------------------------------------