├── modules ├── __init__.py ├── nn │ ├── __init__.py │ └── syncbn.py └── functional │ ├── __init__.py │ ├── csrc │ ├── ext_lib.cpp │ ├── cuda │ │ ├── ext_lib.h │ │ ├── common.h │ │ └── bn_cuda.cu │ └── bn.h │ ├── _csrc.py │ └── syncbn.py ├── requirements.txt ├── LICENSE ├── Dockerfile ├── .gitignore ├── test.py └── README.md /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | future 2 | cffi  # NOTE(review): no cffi import is visible in this repo; the extension is built via torch.utils.cpp_extension 3 | ninja -------------------------------------------------------------------------------- /modules/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import *  # syncbn.py defines no __all__, so its public module-level imports (torch, nn, F, ...) are re-exported too 2 | -------------------------------------------------------------------------------- /modules/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import batchnorm2d_sync  # autograd entry point consumed by modules/nn/syncbn.py 2 | -------------------------------------------------------------------------------- /modules/functional/csrc/ext_lib.cpp: -------------------------------------------------------------------------------- 1 | #include "bn.h"  // syncbn_* dispatchers: CUDA when built WITH_CUDA, AT_ERROR otherwise 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {  // TORCH_EXTENSION_NAME is supplied by the torch.utils.cpp_extension build (load(name="ext_lib") in _csrc.py) 4 | m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation"); 5 | m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation"); 6 | m.def("syncbn_backward_xhat", &syncbn_backward_xhat, 7 | "First part of SyncBN backward computation"); 8 | m.def("syncbn_backward", &syncbn_backward, 9 | "Second part of SyncBN backward computation"); 10 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tamaki Kojima 4 | 5 | Permission is hereby
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /modules/functional/csrc/cuda/ext_lib.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA SyncBN code 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | #include 8 | #include 9 | 10 | /// Sync-BN 11 | std::vector syncbn_sum_sqsum_cuda(const at::Tensor& x); 12 | at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight, 13 | const at::Tensor& bias, const at::Tensor& mean, 14 | const at::Tensor& var, bool affine, float eps); 15 | std::vector syncbn_backward_xhat_cuda(const at::Tensor& dz, 16 | const at::Tensor& x, 17 | const at::Tensor& mean, 18 | const at::Tensor& var, 19 | float eps); 20 | std::vector syncbn_backward_cuda( 21 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 22 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var, 23 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 24 | float eps); 25 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # base container 2 | FROM nvidia/cuda:10.0-devel-ubuntu18.04 3 | 4 | # update 5 | ENV DEBIAN_FRONTEND "noninteractive" 6 | RUN apt-get update -y 7 | RUN apt-get -y \ 8 | -o Dpkg::Options::="--force-confdef" \ 9 | -o Dpkg::Options::="--force-confold" dist-upgrade 10 | 11 | # install basic 12 | RUN apt-get install -y --no-install-recommends \ 13 | less sudo ssh \ 14 | build-essential \ 15 | unzip git curl wget vim tree htop \ 16 | python3-dev python3-tk \ 17 | ninja-build 18 | 19 | # python libs (python3 on ubuntu18.04 is 3.6, which the unversioned get-pip.py no longer supports) 20 | RUN curl https://bootstrap.pypa.io/pip/3.6/get-pip.py | python3 21 | RUN pip3 install \ 22 | future six cffi numpy pillow tqdm Cython awscli ninja 23 | 24
| # install pytorch (torchvision pinned to the release contemporary with torch 1.0.0 so pip does not pull in a different torch) 25 | RUN pip3 install https://download.pytorch.org/whl/cu100/torch-1.0.0-cp36-cp36m-linux_x86_64.whl 26 | RUN pip3 install torchvision==0.2.1 27 | 28 | # clean up 29 | RUN apt-get update -y && apt-get upgrade -y && apt-get autoremove -y 30 | RUN apt-get clean -y && apt-get autoclean -y 31 | RUN rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 32 | 33 | # create mountpoint from host 34 | RUN mkdir -p /workspace 35 | 36 | # create non-root user (names/ids are overridable with --build-arg, and every later instruction follows them) 37 | ARG user_name=ubuntu 38 | ARG user_id=1000 39 | ARG group_name=ubuntu 40 | ARG group_id=1000 41 | RUN groupadd -g ${group_id} ${group_name} 42 | RUN useradd -u ${user_id} -g ${group_id} -d /home/${user_name} --create-home --shell /bin/bash ${user_name} 43 | RUN echo "${user_name} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers 44 | RUN chown -R ${user_name}:${group_name} /home/${user_name} 45 | RUN chown -R ${user_name}:${group_name} /workspace 46 | RUN chsh -s /bin/bash ${user_name} 47 | USER ${user_name} 48 | WORKDIR /home/${user_name} 49 | ENV HOME /home/${user_name} 50 | 51 | -------------------------------------------------------------------------------- /modules/functional/_csrc.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | Extension module loader 5 | 6 | code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import glob 15 | import os.path 16 | 17 | import torch 18 | 19 | try: 20 | from torch.utils.cpp_extension import load 21 | from torch.utils.cpp_extension import CUDA_HOME 22 | except ImportError: 23 | raise ImportError( 24 | "The cpp layer extensions requires PyTorch 0.4 or higher") 25 | 26 | 27 | def _load_C_extensions(): 28
| this_dir = os.path.dirname(os.path.abspath(__file__)) 29 | this_dir = os.path.join(this_dir, "csrc") 30 | 31 | main_file = glob.glob(os.path.join(this_dir, "*.cpp")) 32 | sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp")) 33 | sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu")) 34 | 35 | sources = main_file + sources_cpu 36 | 37 | extra_cflags = [] 38 | extra_cuda_cflags = [] 39 | if torch.cuda.is_available() and CUDA_HOME is not None: 40 | sources.extend(sources_cuda) 41 | extra_cflags = ["-O3", "-DWITH_CUDA"] 42 | extra_cuda_cflags = ["--expt-extended-lambda"] 43 | # glob() above already yields absolute paths under this_dir, so sources are passed to load() as-is 44 | extra_include_paths = [this_dir] 45 | return load( 46 | name="ext_lib", 47 | sources=sources, 48 | extra_cflags=extra_cflags, 49 | extra_include_paths=extra_include_paths, 50 | extra_cuda_cflags=extra_cuda_cflags, 51 | ) 52 | 53 | 54 | _backend = _load_C_extensions() 55 | -------------------------------------------------------------------------------- /modules/functional/csrc/bn.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | SyncBN 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | 8 | #ifdef WITH_CUDA 9 | #include "cuda/ext_lib.h" 10 | #endif 11 | 12 | /// SyncBN 13 | 14 | std::vector syncbn_sum_sqsum(const at::Tensor& x) { 15 | if (x.is_cuda()) { 16 | #ifdef WITH_CUDA 17 | return syncbn_sum_sqsum_cuda(x); 18 | #else 19 | AT_ERROR("Not compiled with GPU support"); 20 | #endif 21 | } else { 22 | AT_ERROR("CPU implementation not supported"); 23 | } 24 | } 25 | 26 | at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight, 27 | const at::Tensor& bias, const at::Tensor& mean, 28 | const at::Tensor& var, bool affine, float eps) { 29 | if (x.is_cuda()) { 30 | #ifdef WITH_CUDA 31 | return syncbn_forward_cuda(x, weight, bias,
mean, var, affine, eps); 32 | #else 33 | AT_ERROR("Not compiled with GPU support"); 34 | #endif 35 | } else { 36 | AT_ERROR("CPU implementation not supported"); 37 | } 38 | } 39 | 40 | std::vector syncbn_backward_xhat(const at::Tensor& dz, 41 | const at::Tensor& x, 42 | const at::Tensor& mean, 43 | const at::Tensor& var, float eps) { 44 | if (dz.is_cuda()) { 45 | #ifdef WITH_CUDA 46 | return syncbn_backward_xhat_cuda(dz, x, mean, var, eps); 47 | #else 48 | AT_ERROR("Not compiled with GPU support"); 49 | #endif 50 | } else { 51 | AT_ERROR("CPU implementation not supported"); 52 | } 53 | } 54 | 55 | std::vector syncbn_backward( 56 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 57 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var, 58 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 59 | float eps) { 60 | if (dz.is_cuda()) { 61 | #ifdef WITH_CUDA 62 | return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz, 63 | sum_dz_xhat, affine, eps); 64 | #else 65 | AT_ERROR("Not compiled with GPU support"); 66 | #endif 67 | } else { 68 | AT_ERROR("CPU implementation not supported"); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | # CMake 35 | CMakeCache.txt 36 | CMakeFiles 37 | CMakeScripts 38 | Testing 39 | Makefile 40 | cmake_install.cmake 41 | install_manifest.txt 42 | compile_commands.json 43 | CTestTestfile.cmake 44 | 45 | # IDE 46 |
.cproject 47 | .project 48 | .pydevproject 49 | .travis.yml 50 | 51 | # Compiled protocol buffers 52 | *.pb.h 53 | *.pb.cc 54 | *_pb2.py 55 | 56 | # IPython notebook checkpoints 57 | .ipynb_checkpoints 58 | 59 | # Editor temporaries 60 | *.swp 61 | *~ 62 | 63 | # Sublime Text settings 64 | *.sublime-workspace 65 | *.sublime-project 66 | 67 | # Eclipse Project settings 68 | *.*project 69 | .settings 70 | 71 | # QtCreator files 72 | *.user 73 | 74 | # PyCharm files 75 | .idea 76 | 77 | # Visual Studio Code files 78 | .vscode 79 | 80 | # OSX dir files 81 | .DS_Store 82 | 83 | # Byte-compiled / optimized / DLL files 84 | __pycache__/ 85 | *.py[cod] 86 | *$py.class 87 | 88 | # C extensions 89 | *.so 90 | 91 | # Distribution / packaging 92 | .Python 93 | build/ 94 | develop-eggs/ 95 | dist/ 96 | downloads/ 97 | eggs/ 98 | .eggs/ 99 | lib/ 100 | lib64/ 101 | parts/ 102 | sdist/ 103 | var/ 104 | wheels/ 105 | *.egg-info/ 106 | .installed.cfg 107 | *.egg 108 | 109 | # PyInstaller 110 | # Usually these files are written by a python script from a template 111 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
112 | *.manifest 113 | *.spec 114 | 115 | # Installer logs 116 | pip-log.txt 117 | pip-delete-this-directory.txt 118 | 119 | # Unit test / coverage reports 120 | htmlcov/ 121 | .tox/ 122 | .coverage 123 | .coverage.* 124 | .cache 125 | nosetests.xml 126 | coverage.xml 127 | *.cover 128 | .hypothesis/ 129 | 130 | # Translations 131 | *.mo 132 | *.pot 133 | 134 | # Django stuff: 135 | *.log 136 | local_settings.py 137 | 138 | # Flask stuff: 139 | instance/ 140 | .webassets-cache 141 | 142 | # Scrapy stuff: 143 | .scrapy 144 | 145 | # Sphinx documentation 146 | docs/_build/ 147 | 148 | # PyBuilder 149 | target/ 150 | 151 | # Jupyter Notebook 152 | .ipynb_checkpoints 153 | 154 | # pyenv 155 | .python-version 156 | 157 | # celery beat schedule file 158 | celerybeat-schedule 159 | 160 | # SageMath parsed files 161 | *.sage.py 162 | 163 | # Environments 164 | .env 165 | .venv 166 | env/ 167 | venv/ 168 | ENV/ 169 | 170 | # Spyder project settings 171 | .spyderproject 172 | .spyproject 173 | 174 | # Rope project settings 175 | .ropeproject 176 | 177 | # mkdocs documentation 178 | /site 179 | 180 | # mypy 181 | .mypy_cache/ 182 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | Test for BatchNorm2dSync with multi-gpu 5 | 6 | /*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import sys 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | from torch.nn import functional as F 17 | sys.path.append("./") 18 | from modules import nn as NN 19 | 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | def init_weight(model): 24 | for m in model.modules(): 25 | if 
isinstance(m, nn.Conv2d): 26 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 27 | m.weight.data.normal_(0, np.sqrt(2. / n)) 28 | elif isinstance(m, NN.BatchNorm2d) or isinstance(m, nn.BatchNorm2d): 29 | m.weight.data.fill_(1) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.Linear): 32 | m.bias.data.zero_() 33 | 34 | num_gpu = torch.cuda.device_count() 35 | print("num_gpu={}".format(num_gpu)) 36 | if num_gpu < 2: 37 | print("No multi-gpu found. NN.BatchNorm2d will act as normal nn.BatchNorm2d") 38 | 39 | m1 = nn.Sequential( 40 | nn.Conv2d(3, 3, 1, 1, bias=False), 41 | nn.BatchNorm2d(3), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(3, 3, 1, 1, bias=False), 44 | nn.BatchNorm2d(3), 45 | ).cuda() 46 | torch.manual_seed(123) 47 | init_weight(m1) 48 | m2 = nn.Sequential( 49 | nn.Conv2d(3, 3, 1, 1, bias=False), 50 | NN.BatchNorm2d(3), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(3, 3, 1, 1, bias=False), 53 | NN.BatchNorm2d(3), 54 | ).cuda() 55 | torch.manual_seed(123) 56 | init_weight(m2) 57 | m2 = nn.DataParallel(m2, device_ids=range(num_gpu)) 58 | o1 = torch.optim.SGD(m1.parameters(), 1e-3) 59 | o2 = torch.optim.SGD(m2.parameters(), 1e-3) 60 | y = torch.ones(num_gpu).float().cuda() 61 | torch.manual_seed(123) 62 | for _ in range(100): 63 | x = torch.rand(num_gpu, 3, 2, 2).cuda() 64 | o1.zero_grad() 65 | z1 = m1(x) 66 | l1 = F.mse_loss(z1.mean(-1).mean(-1).mean(-1), y) 67 | l1.backward() 68 | o1.step() 69 | o2.zero_grad() 70 | z2 = m2(x) 71 | l2 = F.mse_loss(z2.mean(-1).mean(-1).mean(-1), y) 72 | l2.backward() 73 | o2.step() 74 | print(m2.module[1].bias.grad - m1[1].bias.grad) 75 | print(m2.module[1].weight.grad - m1[1].weight.grad) 76 | print(m2.module[-1].bias.grad - m1[-1].bias.grad) 77 | print(m2.module[-1].weight.grad - m1[-1].weight.grad) 78 | m2 = m2.module 79 | print("===============================") 80 | print("m1(nn.BatchNorm2d) running_mean", 81 | m1[1].running_mean, m1[-1].running_mean) 82 | print("m2(NN.BatchNorm2d) running_mean", 83 | 
m2[1].running_mean, m2[-1].running_mean) 84 | print("m1(nn.BatchNorm2d) running_var", m1[1].running_var, m1[-1].running_var) 85 | print("m2(NN.BatchNorm2d) running_var", m2[1].running_var, m2[-1].running_var) 86 | print("m1(nn.BatchNorm2d) weight", m1[1].weight, m1[-1].weight) 87 | print("m2(NN.BatchNorm2d) weight", m2[1].weight, m2[-1].weight) 88 | print("m1(nn.BatchNorm2d) bias", m1[1].bias, m1[-1].bias) 89 | print("m2(NN.BatchNorm2d) bias", m2[1].bias, m2[-1].bias) 90 | -------------------------------------------------------------------------------- /modules/functional/csrc/cuda/common.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA utility funcs 4 | 5 | code referenced from : https://github.com/mapillary/inplace_abn 6 | 7 | *****************************************************************************/ 8 | #pragma once 9 | 10 | #include 11 | 12 | // Checks 13 | #ifndef AT_CHECK 14 | #define AT_CHECK AT_ASSERT 15 | #endif 16 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 18 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 19 | 20 | /* 21 | * General settings 22 | */ 23 | const int WARP_SIZE = 32; 24 | const int MAX_BLOCK_SIZE = 512; 25 | 26 | template 27 | struct Pair { 28 | T v1, v2; 29 | __device__ Pair() {} 30 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 31 | __device__ Pair(T v) : v1(v), v2(v) {} 32 | __device__ Pair(int v) : v1(v), v2(v) {} 33 | __device__ Pair &operator+=(const Pair &a) { 34 | v1 += a.v1; 35 | v2 += a.v2; 36 | return *this; 37 | } 38 | }; 39 | 40 | /* 41 | * Utility functions 42 | */ 43 | template 44 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, 45 | int width = warpSize, 46 | unsigned int mask = 0xffffffff) { 47 | #if CUDART_VERSION >= 9000 48 | 
return __shfl_xor_sync(mask, value, laneMask, width); 49 | #else 50 | return __shfl_xor(value, laneMask, width); 51 | #endif 52 | } 53 | 54 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 55 | 56 | static int getNumThreads(int nElem) { 57 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 58 | for (int i = 0; i != 5; ++i) { 59 | if (nElem <= threadSizes[i]) { 60 | return threadSizes[i]; 61 | } 62 | } 63 | return MAX_BLOCK_SIZE; 64 | } 65 | 66 | template 67 | static __device__ __forceinline__ T warpSum(T val) { 68 | #if __CUDA_ARCH__ >= 300 69 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 70 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 71 | } 72 | #else 73 | __shared__ T values[MAX_BLOCK_SIZE]; 74 | values[threadIdx.x] = val; 75 | __threadfence_block(); 76 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 77 | for (int i = 1; i < WARP_SIZE; i++) { 78 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 79 | } 80 | #endif 81 | return val; 82 | } 83 | 84 | template 85 | static __device__ __forceinline__ Pair warpSum(Pair value) { 86 | value.v1 = warpSum(value.v1); 87 | value.v2 = warpSum(value.v2); 88 | return value; 89 | } 90 | 91 | template 92 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 93 | T sum = (T)0; 94 | for (int batch = 0; batch < N; ++batch) { 95 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 96 | sum += op(batch, plane, x); 97 | } 98 | } 99 | 100 | // sum over NumThreads within a warp 101 | sum = warpSum(sum); 102 | 103 | // 'transpose', and reduce within warp again 104 | __shared__ T shared[32]; 105 | __syncthreads(); 106 | if (threadIdx.x % WARP_SIZE == 0) { 107 | shared[threadIdx.x / WARP_SIZE] = sum; 108 | } 109 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 110 | // zero out the other entries in shared 111 | shared[threadIdx.x] = (T)0; 112 | } 113 | __syncthreads(); 114 | if (threadIdx.x / WARP_SIZE == 0) { 115 | sum = 
warpSum(shared[threadIdx.x]); 116 | if (threadIdx.x == 0) { 117 | shared[0] = sum; 118 | } 119 | } 120 | __syncthreads(); 121 | 122 | // Everyone picks it up, should be broadcast into the whole gradInput 123 | return shared[0]; 124 | } -------------------------------------------------------------------------------- /modules/nn/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | /*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | try: 13 | # python 3 14 | from queue import Queue 15 | except ImportError: 16 | # python 2 17 | from Queue import Queue 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn import functional as F 22 | from torch.nn.parameter import Parameter 23 | from modules.functional import batchnorm2d_sync 24 | 25 | 26 | class _BatchNorm(nn.Module): 27 | """ 28 | Customized BatchNorm from nn.BatchNorm 29 | >> added freeze attribute to enable bn freeze. 
30 | """ 31 | 32 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 33 | track_running_stats=True): 34 | super(_BatchNorm, self).__init__() 35 | self.num_features = num_features 36 | self.eps = eps 37 | self.momentum = momentum 38 | self.affine = affine 39 | self.track_running_stats = track_running_stats 40 | self.freezed = False 41 | if self.affine: 42 | self.weight = Parameter(torch.Tensor(num_features)) 43 | self.bias = Parameter(torch.Tensor(num_features)) 44 | else: 45 | self.register_parameter('weight', None) 46 | self.register_parameter('bias', None) 47 | if self.track_running_stats: 48 | self.register_buffer('running_mean', torch.zeros(num_features)) 49 | self.register_buffer('running_var', torch.ones(num_features)) 50 | else: 51 | self.register_parameter('running_mean', None) 52 | self.register_parameter('running_var', None) 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | if self.track_running_stats: 57 | self.running_mean.zero_() 58 | self.running_var.fill_(1) 59 | if self.affine: 60 | self.weight.data.uniform_() 61 | self.bias.data.zero_() 62 | 63 | def _check_input_dim(self, input): 64 | return NotImplemented 65 | 66 | def forward(self, input): 67 | self._check_input_dim(input) 68 | 69 | compute_stats = not self.freezed and \ 70 | self.training and self.track_running_stats 71 | 72 | ret = F.batch_norm(input, self.running_mean, self.running_var, 73 | self.weight, self.bias, compute_stats, 74 | self.momentum, self.eps) 75 | return ret 76 | 77 | def extra_repr(self): 78 | return '{num_features}, eps={eps}, momentum={momentum}, '\ 79 | 'affine={affine}, ' \ 80 | 'track_running_stats={track_running_stats}'.format( 81 | **self.__dict__) 82 | 83 | 84 | class BatchNorm2dNoSync(_BatchNorm): 85 | """ 86 | Equivalent to nn.BatchNorm2d 87 | """ 88 | 89 | def _check_input_dim(self, input): 90 | if input.dim() != 4: 91 | raise ValueError('expected 4D input (got {}D input)' 92 | .format(input.dim())) 93 | 94 | 95 | class 
BatchNorm2dSync(BatchNorm2dNoSync): 96 | """ 97 | BatchNorm2d with automatic multi-GPU Sync 98 | """ 99 | 100 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 101 | track_running_stats=True): 102 | super(BatchNorm2dSync, self).__init__( 103 | num_features, eps=eps, momentum=momentum, affine=affine, 104 | track_running_stats=track_running_stats) 105 | self.sync_enabled = True 106 | self.devices = list(range(torch.cuda.device_count())) 107 | if len(self.devices) > 1: 108 | # Initialize queues 109 | self.worker_ids = self.devices[1:] 110 | self.master_queue = Queue(len(self.worker_ids)) 111 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 112 | 113 | def forward(self, x): 114 | compute_stats = not self.freezed and \ 115 | self.training and self.track_running_stats 116 | if self.sync_enabled and compute_stats and len(self.devices) > 1: 117 | if x.get_device() == self.devices[0]: 118 | # Master mode 119 | extra = { 120 | "is_master": True, 121 | "master_queue": self.master_queue, 122 | "worker_queues": self.worker_queues, 123 | "worker_ids": self.worker_ids 124 | } 125 | else: 126 | # Worker mode 127 | extra = { 128 | "is_master": False, 129 | "master_queue": self.master_queue, 130 | "worker_queue": self.worker_queues[ 131 | self.worker_ids.index(x.get_device())] 132 | } 133 | return batchnorm2d_sync(x, self.weight, self.bias, 134 | self.running_mean, self.running_var, 135 | extra, compute_stats, self.momentum, 136 | self.eps) 137 | return super(BatchNorm2dSync, self).forward(x) 138 | 139 | def __repr__(self): 140 | """repr""" 141 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 142 | 'affine={affine}, ' \ 143 | 'track_running_stats={track_running_stats},' \ 144 | 'devices={devices})' 145 | return rep.format(name=self.__class__.__name__, **self.__dict__) 146 | 147 | #BatchNorm2d = BatchNorm2dNoSync 148 | BatchNorm2d = BatchNorm2dSync 149 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-syncbn 2 | 3 | Tamaki Kojima(tamakoji@gmail.com) 4 | 5 | ## Announcement 6 | 7 | **PyTorch 1.0 support** 8 | 9 | ## Overview 10 | This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global statistics across GPUs instead of per-GPU local statistics. SyncBN becomes important when input images are large and multiple GPUs must be used to increase the minibatch size for training. 11 | 12 | The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn) 13 | 14 | ## Remarks 15 | - Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel` 16 | - Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module, since it does not mark tensors for in-place operation 17 | - You can plug it into arbitrary modules written in PyTorch to enable Synchronized BatchNorm 18 | - The backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d` 19 | 20 | ## Requirements 21 | For PyTorch, please refer to https://pytorch.org/ 22 | 23 | NOTE : The code is tested only with PyTorch v1.0.0, CUDA10/CuDNN7.4.2 on ubuntu18.04 24 | 25 | It uses PyTorch's JIT extension mechanism to compile seamlessly with ninja. Please install ninja-build before use. 26 | 27 | ``` 28 | sudo apt-get install ninja-build 29 | ``` 30 | 31 | Also install all Python dependencies. For pip, run: 32 | 33 | 34 | ``` 35 | pip install -U -r requirements.txt 36 | ``` 37 | 38 | ## Build 39 | 40 | There is no need to build; just run it and the JIT compiler will take care of the rest. 41 | JIT compilation and C++ extensions are supported since PyTorch 0.4; however, PyTorch >= 1.0 is highly recommended because of major design changes.
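42 | 43 | ## Usage 44 | Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d` 45 | 46 | ```python 47 | import torch 48 | from torch import nn 49 | from modules import nn as NN 50 | num_gpu = torch.cuda.device_count() 51 | model = nn.Sequential( 52 | nn.Conv2d(3, 3, 1, 1, bias=False), 53 | NN.BatchNorm2d(3), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(3, 3, 1, 1, bias=False), 56 | NN.BatchNorm2d(3), 57 | ).cuda() 58 | model = nn.DataParallel(model, device_ids=range(num_gpu)) 59 | x = torch.rand(num_gpu, 3, 2, 2).cuda() 60 | z = model(x) 61 | ``` 62 | 63 | ## Math 64 | 65 | ### Forward 66 | 1. compute $\sum_i x_i$ and $\sum_i x_i^2$ on each gpu 67 | 2. gather all sums from the workers at the master and compute the global mean $\mu$ and variance $\sigma^2$, where $N$ is the total number of elements per channel across all gpus: 68 | 69 | $$\mu = \frac{1}{N}\sum_{i=1}^{N} x_i$$ 70 | 71 | and 72 | 73 | $$\sigma^2 = \frac{1}{N}\sum_{i=1}^{N} x_i^2 - \mu^2$$ 74 | 75 | The global stats are then shared with all gpus, and running_mean and running_var are updated by a moving average of the global stats (running_var uses the unbiased estimate $\frac{N}{N-1}\sigma^2$). 76 | 77 | 3. forward batchnorm using the global stats by 78 | 79 | $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$ 80 | 81 | and then 82 | 83 | $$y_i = \gamma \hat{x}_i + \beta$$ 84 | 85 | where $\gamma$ is the weight parameter and $\beta$ is the bias parameter. 86 | 87 | 4. save $x, \gamma, \beta, \mu, \sigma^2$ for backward 88 | 89 | ### Backward 90 | 91 | 1. Restore the saved $x, \gamma, \beta, \mu, \sigma^2$ 92 | 93 | 2. Compute the following sums on each gpu 94 | 95 | $$\sum_i \frac{\partial J}{\partial y_i}$$ 96 | 97 | and 98 | 99 | $$\sum_i \frac{\partial J}{\partial y_i}\hat{x}_i$$ 100 | 101 | where 102 | 103 | $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$. Then gather them at the master node to compute the global sums, and normalize them by $N$, the total number of elements per channel. The global sums are then shared among all gpus. 104 | 105 | 3. compute the gradients using the global stats 106 | 107 | $$\frac{\partial J}{\partial x_i} = \frac{\partial J}{\partial \hat{x}_i}\frac{1}{\sqrt{\sigma^2+\epsilon}} + \frac{\partial J}{\partial \sigma^2}\frac{2(x_i-\mu)}{N} + \frac{\partial J}{\partial \mu}\frac{1}{N}$$ 108 | 109 | where 110 | 111 | $$\frac{\partial J}{\partial \hat{x}_i} = \gamma\frac{\partial J}{\partial y_i}, \qquad \frac{\partial J}{\partial \sigma^2} = -\frac{1}{2}(\sigma^2+\epsilon)^{-3/2}\sum_j \frac{\partial J}{\partial \hat{x}_j}(x_j-\mu)$$ 112 | 113 | and 114 | 115 | $$\frac{\partial J}{\partial \mu} = -\frac{1}{\sqrt{\sigma^2+\epsilon}}\sum_j \frac{\partial J}{\partial \hat{x}_j}$$ 116 | 117 | (the $\partial J/\partial\sigma^2$ contribution to $\partial J/\partial\mu$ vanishes because $\sum_j (x_j-\mu)=0$), and finally, 118 | 119 | $$\frac{\partial J}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}\left(\frac{\partial J}{\partial y_i} - \frac{1}{N}\sum_j \frac{\partial J}{\partial y_j} - \hat{x}_i\,\frac{1}{N}\sum_j \frac{\partial J}{\partial y_j}\hat{x}_j\right)$$ 120 | 121 | $$\frac{\partial J}{\partial \gamma} = \sum_i \frac{\partial J}{\partial y_i}\hat{x}_i$$ 122 | 123 | $$\frac{\partial J}{\partial \beta} = \sum_i \frac{\partial J}{\partial y_i}$$ 124 | 125 | Note that the implementation normalizes by $N$ in step (2), so the code does not literally follow the equations above, but it is mathematically equivalent.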
126 | 127 | You can go deeper on above explanation at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/) -------------------------------------------------------------------------------- /modules/functional/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | code referenced from : https://github.com/mapillary/inplace_abn 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch.cuda.comm as comm 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from ._csrc import _backend 18 | 19 | 20 | def _count_samples(x): 21 | count = 1 22 | for i, s in enumerate(x.size()): 23 | if i != 1: 24 | count *= s 25 | return count 26 | 27 | 28 | class BatchNorm2dSyncFunc(Function): 29 | 30 | @staticmethod 31 | def forward(ctx, x, weight, bias, running_mean, running_var, 32 | extra, compute_stats=True, momentum=0.1, eps=1e-05): 33 | def _parse_extra(ctx, extra): 34 | ctx.is_master = extra["is_master"] 35 | if ctx.is_master: 36 | ctx.master_queue = extra["master_queue"] 37 | ctx.worker_queues = extra["worker_queues"] 38 | ctx.worker_ids = extra["worker_ids"] 39 | else: 40 | ctx.master_queue = extra["master_queue"] 41 | ctx.worker_queue = extra["worker_queue"] 42 | # Save context 43 | if extra is not None: 44 | _parse_extra(ctx, extra) 45 | ctx.compute_stats = compute_stats 46 | ctx.momentum = momentum 47 | ctx.eps = eps 48 | ctx.affine = weight is not None and bias is not None 49 | if ctx.compute_stats: 50 | N = _count_samples(x) * (ctx.master_queue.maxsize + 1) 51 | assert N > 1 52 | # 1. 
compute sum(x) and sum(x^2) 53 | xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach()) 54 | if ctx.is_master: 55 | xsums, xsqsums = [xsum], [xsqsum] 56 | # master : gatther all sum(x) and sum(x^2) from slaves 57 | for _ in range(ctx.master_queue.maxsize): 58 | xsum_w, xsqsum_w = ctx.master_queue.get() 59 | ctx.master_queue.task_done() 60 | xsums.append(xsum_w) 61 | xsqsums.append(xsqsum_w) 62 | xsum = comm.reduce_add(xsums) 63 | xsqsum = comm.reduce_add(xsqsums) 64 | mean = xsum / N 65 | sumvar = xsqsum - xsum * mean 66 | var = sumvar / N 67 | uvar = sumvar / (N - 1) 68 | # master : broadcast global mean, variance to all slaves 69 | tensors = comm.broadcast_coalesced( 70 | (mean, uvar, var), [mean.get_device()] + ctx.worker_ids) 71 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 72 | queue.put(ts) 73 | else: 74 | # slave : send sum(x) and sum(x^2) to master 75 | ctx.master_queue.put((xsum, xsqsum)) 76 | # slave : get global mean and variance 77 | mean, uvar, var = ctx.worker_queue.get() 78 | ctx.worker_queue.task_done() 79 | 80 | # Update running stats 81 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 82 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar) 83 | ctx.N = N 84 | ctx.save_for_backward(x, weight, bias, mean, var) 85 | else: 86 | mean, var = running_mean, running_var 87 | 88 | # do batch norm forward 89 | z = _backend.syncbn_forward(x, weight, bias, mean, var, 90 | ctx.affine, ctx.eps) 91 | return z 92 | 93 | @staticmethod 94 | @once_differentiable 95 | def backward(ctx, dz): 96 | x, weight, bias, mean, var = ctx.saved_tensors 97 | dz = dz.contiguous() 98 | 99 | # 1. 
compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i}) 100 | sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat( 101 | dz, x, mean, var, ctx.eps) 102 | if ctx.is_master: 103 | sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat] 104 | # master : gatther from slaves 105 | for _ in range(ctx.master_queue.maxsize): 106 | sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get() 107 | ctx.master_queue.task_done() 108 | sum_dzs.append(sum_dz_w) 109 | sum_dz_xhats.append(sum_dz_xhat_w) 110 | # master : compute global stats 111 | sum_dz = comm.reduce_add(sum_dzs) 112 | sum_dz_xhat = comm.reduce_add(sum_dz_xhats) 113 | sum_dz /= ctx.N 114 | sum_dz_xhat /= ctx.N 115 | # master : broadcast global stats 116 | tensors = comm.broadcast_coalesced( 117 | (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids) 118 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 119 | queue.put(ts) 120 | else: 121 | # slave : send to master 122 | ctx.master_queue.put((sum_dz, sum_dz_xhat)) 123 | # slave : get global stats 124 | sum_dz, sum_dz_xhat = ctx.worker_queue.get() 125 | ctx.worker_queue.task_done() 126 | 127 | # do batch norm backward 128 | dx, dweight, dbias = _backend.syncbn_backward( 129 | dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat, 130 | ctx.affine, ctx.eps) 131 | 132 | return dx, dweight, dbias, \ 133 | None, None, None, None, None, None 134 | 135 | batchnorm2d_sync = BatchNorm2dSyncFunc.apply 136 | 137 | __all__ = ["batchnorm2d_sync"] 138 | -------------------------------------------------------------------------------- /modules/functional/csrc/cuda/bn_cuda.cu: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA SyncBN code 4 | 5 | code referenced from : https://github.com/mapillary/inplace_abn 6 | 7 | *****************************************************************************/ 8 | #include 9 | #include 10 | #include 11 | #include 12 | 
#include "cuda/common.h"

// NOTE(review): every kernel below is launched as <<<blocks, threads>>> with
// no stream argument, i.e. on the default stream rather than PyTorch's current
// stream. Element offsets ((batch * chn + plane) * sp + n) are computed in
// 32-bit `int`, so inputs with more than INT_MAX elements would overflow them.

// Utilities

// Split x's shape into (num, chn, sp): num = size(0), chn = size(1) and
// sp = product of all remaining dimensions (1 when x is 2-D).
void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) {
  num = x.size(0);
  chn = x.size(1);
  sp = 1;
  for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i);
}

/// SyncBN

// Per-element functor handed to reduce(): yields (x, x^2) for element
// (batch, plane, n) of a buffer laid out as (num, chn, sp).
template <typename T>
struct SqSumOp {
  __device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    T x = tensor[(batch * chn + plane) * sp + n];
    return Pair<T>(x, x * x);  // x, x^2
  }
  const T *tensor;
  const int chn;
  const int sp;
};

// One block per channel (plane = blockIdx.x). Thread 0 writes the channel's
// sum(x) and sum(x^2). reduce() lives in cuda/common.h; it is assumed here to
// leave the block-wide total in thread 0's result.
template <typename T>
__global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum,
                                        int num, int chn, int sp) {
  int plane = blockIdx.x;
  Pair<T> res =
      reduce<Pair<T>, SqSumOp<T>>(SqSumOp<T>(x, chn, sp), plane, num, chn, sp);
  __syncthreads();
  if (threadIdx.x == 0) {
    sum[plane] = res.v1;
    sqsum[plane] = res.v2;
  }
}

// Per-channel sum(x) and sum(x^2). Returns {sum, sqsum}, each of shape {chn}
// with x's dtype and device.
std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor &x) {
  CHECK_INPUT(x);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Prepare output tensors (every entry is written by the kernel)
  auto sum = at::empty({chn}, x.options());
  auto sqsum = at::empty({chn}, x.options());

  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_sum_sqsum_cuda", ([&] {
        syncbn_sum_sqsum_kernel<scalar_t><<<blocks, threads>>>(
            x.data<scalar_t>(), sum.data<scalar_t>(),
            sqsum.data<scalar_t>(), num, chn, sp);
      }));
  return {sum, sqsum};
}

// z = (x - mean) * rsqrt(var + eps) * weight + bias, per channel.
// weight/bias are taken as 1/0 when affine is false.
template <typename T>
__global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight,
                                      const T *bias, const T *mean,
                                      const T *var, bool affine, float eps,
                                      int num, int chn, int sp) {
  int plane = blockIdx.x;
  T _mean = mean[plane];
  T _var = var[plane];
  T _weight = affine ? weight[plane] : T(1);
  T _bias = affine ? bias[plane] : T(0);
  // Held in T (not float) so double inputs are not normalised with a
  // single-precision inverse std; matches the backward kernels.
  T _invstd = T(0);
  if (_var || eps) {
    _invstd = rsqrt(_var + eps);
  }
  for (int batch = 0; batch < num; ++batch) {
    for (int n = threadIdx.x; n < sp; n += blockDim.x) {
      int idx = (batch * chn + plane) * sp + n;
      T _xhat = (x[idx] - _mean) * _invstd;
      z[idx] = _xhat * _weight + _bias;
    }
  }
}

// Normalise x with the given per-channel mean/var and apply the optional
// affine transform. Returns a tensor shaped like x.
at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight,
                               const at::Tensor &bias, const at::Tensor &mean,
                               const at::Tensor &var, bool affine, float eps) {
  CHECK_INPUT(x);
  CHECK_INPUT(weight);
  CHECK_INPUT(bias);
  CHECK_INPUT(mean);
  CHECK_INPUT(var);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // The kernel writes every element of z, so no zero-fill is needed.
  auto z = at::empty_like(x);

  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_forward_cuda", ([&] {
        syncbn_forward_kernel<scalar_t><<<blocks, threads>>>(
            z.data<scalar_t>(), x.data<scalar_t>(),
            weight.data<scalar_t>(), bias.data<scalar_t>(),
            mean.data<scalar_t>(), var.data<scalar_t>(),
            affine, eps, num, chn, sp);
      }));
  return z;
}

// Per-element functor handed to reduce(): yields (dz, dz * xhat) with
// xhat = (x - mean) * invstd for element (batch, plane, n).
// Members are declared in the same order they are initialised.
template <typename T>
struct XHatOp {
  __device__ XHatOp(T _invstd, T _mean, const T *_dz, const T *_x, int c, int s)
      : invstd(_invstd), mean(_mean), dz(_dz), x(_x), chn(c), sp(s) {}
  __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
    int idx = (batch * chn + plane) * sp + n;
    T _xhat = (x[idx] - mean) * invstd;
    T _dz = dz[idx];
    return Pair<T>(_dz, _dz * _xhat);  // dz, dz * x_hat
  }
  const T invstd;
  const T mean;
  const T *dz;
  const T *x;
  const int chn;
  const int sp;
};

// One block per channel. Thread 0 writes sum(dz) and sum(dz * xhat) for the
// channel.
template <typename T>
__global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x,
                                            const T *mean, const T *var,
                                            T *sum_dz, T *sum_dz_xhat,
                                            float eps, int num, int chn,
                                            int sp) {
  int plane = blockIdx.x;
  T _mean = mean[plane];
  T _var = var[plane];
  T _invstd = T(0);
  if (_var || eps) {
    _invstd = rsqrt(_var + eps);
  }
  Pair<T> res = reduce<Pair<T>, XHatOp<T>>(
      XHatOp<T>(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp);
  __syncthreads();
  if (threadIdx.x == 0) {
    // \sum(\frac{dJ}{dy_i})
    sum_dz[plane] = res.v1;
    // \sum(\frac{dJ}{dy_i}*\hat{x_i})
    sum_dz_xhat[plane] = res.v2;
  }
}

// First backward pass: per-channel sum(dz) and sum(dz * xhat) on this device.
// Returns {sum_dz, sum_dz_xhat}, each of shape {chn}.
std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor &dz,
                                                  const at::Tensor &x,
                                                  const at::Tensor &mean,
                                                  const at::Tensor &var,
                                                  float eps) {
  CHECK_INPUT(dz);
  CHECK_INPUT(x);
  CHECK_INPUT(mean);
  CHECK_INPUT(var);
  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);
  // Prepare output tensors (every entry is written by the kernel)
  auto sum_dz = at::empty({chn}, x.options());
  auto sum_dz_xhat = at::empty({chn}, x.options());
  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_backward_xhat_cuda", ([&] {
        syncbn_backward_xhat_kernel<scalar_t><<<blocks, threads>>>(
            dz.data<scalar_t>(), x.data<scalar_t>(), mean.data<scalar_t>(),
            var.data<scalar_t>(), sum_dz.data<scalar_t>(),
            sum_dz_xhat.data<scalar_t>(), eps, num, chn, sp);
      }));
  return {sum_dz, sum_dz_xhat};
}

// Second backward pass. Writes dx for every element and, when affine, adds
// this device's share of dweight/dbias.
template <typename T>
__global__ void syncbn_backward_kernel(const T *dz, const T *x, const T *weight,
                                       const T *bias, const T *mean,
                                       const T *var, const T *sum_dz,
                                       const T *sum_dz_xhat, T *dx, T *dweight,
                                       T *dbias, bool affine, float eps,
                                       int num, int chn, int sp) {
  int plane = blockIdx.x;
  T _mean = mean[plane];
  T _var = var[plane];
  T _weight = affine ? weight[plane] : T(1);
  T _sum_dz = sum_dz[plane];
  T _sum_dz_xhat = sum_dz_xhat[plane];
  T _invstd = T(0);
  if (_var || eps) {
    _invstd = rsqrt(_var + eps);
  }
  /*
    \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
      N\frac{dJ}{d\hat{x_i}} -
      \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
      \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
    )
    Note : N is omitted here since it will be accumulated and
    _sum_dz and _sum_dz_xhat expected to be already normalized
    before the call.
  */
  if (dx) {
    T _mul = _weight * _invstd;
    for (int batch = 0; batch < num; ++batch) {
      for (int n = threadIdx.x; n < sp; n += blockDim.x) {
        int idx = (batch * chn + plane) * sp + n;
        T _xhat = (x[idx] - _mean) * _invstd;
        dx[idx] = (dz[idx] - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
      }
    }
  }
  __syncthreads();
  if (threadIdx.x == 0) {
    if (affine) {
      // The incoming sums are normalised, so scale them back by this
      // device's element count. Each factor is converted to T first to avoid
      // overflowing int in num * sp.
      T _norm = T(num) * T(sp);
      dweight[plane] += _sum_dz_xhat * _norm;
      dbias[plane] += _sum_dz * _norm;
    }
  }
}

// Second backward pass on this device. Returns {dx, dweight, dbias}.
// dweight/dbias stay zero when affine is false.
std::vector<at::Tensor> syncbn_backward_cuda(
    const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight,
    const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var,
    const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine,
    float eps) {
  CHECK_INPUT(dz);
  CHECK_INPUT(x);
  CHECK_INPUT(weight);
  CHECK_INPUT(bias);
  CHECK_INPUT(mean);
  CHECK_INPUT(var);
  CHECK_INPUT(sum_dz);
  CHECK_INPUT(sum_dz_xhat);

  // Extract dimensions
  int64_t num, chn, sp;
  get_dims(x, num, chn, sp);

  // Prepare output tensors. dx is fully overwritten by the kernel.
  // dweight/dbias must start at zero because the kernel accumulates into them.
  auto dx = at::empty_like(dz);
  auto dweight = at::zeros_like(weight);
  auto dbias = at::zeros_like(bias);

  // Run kernel: one block per channel
  dim3 blocks(chn);
  dim3 threads(getNumThreads(sp));
  AT_DISPATCH_FLOATING_TYPES(
      x.type(), "syncbn_backward_cuda", ([&] {
        syncbn_backward_kernel<scalar_t><<<blocks, threads>>>(
            dz.data<scalar_t>(), x.data<scalar_t>(), weight.data<scalar_t>(),
            bias.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(),
            sum_dz.data<scalar_t>(), sum_dz_xhat.data<scalar_t>(),
            dx.data<scalar_t>(), dweight.data<scalar_t>(),
            dbias.data<scalar_t>(), affine, eps, num, chn, sp);
      }));
  return {dx, dweight, dbias};
}