├── README.md
├── __init__.py
├── func.py
├── parallel.py
├── sync.py
├── sync_bn.py
├── test.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Synchronized-BatchNormalization #

***Multi-GPU Synchronized Batch Normalization implementation in PyTorch***

----------

## Introduction ##

This module is a synchronized version of Batch Normalization for multi-GPU training, aka 'Syn-BN': the mean and standard deviation are reduced across all devices during training.

Traditionally, when 'nn.DataParallel' is used to wrap a module for training, the built-in PyTorch BatchNorm normalizes the tensor on each device using only the statistics of that device, so the statistics may be inaccurate.

Instead, in this synchronized version, the statistics are computed over all training samples distributed across the devices.

In the single-GPU or CPU-only case, this module behaves exactly the same as the built-in PyTorch implementation.

Note that this module may still have some design problems. If you have any questions or suggestions, please feel free to open an issue or submit a pull request, and let's make it better!

----------


## Why Syn-BN ? ##

For some computer vision tasks, such as classification and detection, the working batch size is usually large enough to obtain good statistics, so there is no need to synchronize the BN layers during training, and synchronization would only slow training down.

However, other computer vision tasks, such as semantic segmentation, are dense prediction problems and are very memory-consuming, so the working batch size is usually very small (typically 2 or 4 per GPU), and performance suffers without synchronization.

(*The importance of synchronized batch normalization in object detection has been demonstrated with an extensive analysis in the paper [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240).*)

----------

## How to use ? ##

To use Syn-BN, I customized a data parallel wrapper named 'DataParallelWithCallBack', which inherits from nn.DataParallel and calls a callback function during data parallel replication. This introduces a slight difference from the typical usage of nn.DataParallel.

Use it with the provided, customized data parallel wrapper (defined in parallel.py):

    from parallel import DataParallelWithCallBack
    from sync_bn import SynchronizedBatchNorm2d

    sync_bn = SynchronizedBatchNorm2d(
        num_features=3, eps=1e-5, momentum=0.1, affine=True, sync_timeout=15.
    )
    sync_bn = DataParallelWithCallBack(sync_bn, device_ids=[0, 1])
    sync_bn.to(device)

Or, if you have already defined a model wrapped in nn.DataParallel like:

    from torchvision import models

    m = models.resnet50(pretrained=True)
    m = nn.DataParallel(m, device_ids=[0, 1])
    m.to(device)

then you can use the 'convert_model' function to convert your model to use Syn-BN easily:

    from func import convert_model

    m = convert_model(m)

This will change all BNs contained in your model into Syn-BNs.

----------

## Author ##

chrisway(cw), 2020.
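
----------

## Quick sanity check ##

A minimal way to verify that the statistics are really synchronized; this mirrors what test.py prints and assumes at least two visible GPUs (device_ids=[0, 1] here). After a forward pass in training mode, the per-channel mean and standard deviation of the output, computed over all samples from all devices, should be close to 0 and 1 with the default affine initialization:

    import torch
    from parallel import DataParallelWithCallBack
    from sync_bn import SynchronizedBatchNorm2d

    device = torch.device('cuda:0')
    sync_bn = SynchronizedBatchNorm2d(num_features=3)
    sync_bn = DataParallelWithCallBack(sync_bn, device_ids=[0, 1])
    sync_bn.to(device)

    x = torch.randn(4, 3, 256, 256, device=device)
    y = sync_bn(x)
    print(y.mean(dim=(0, 2, 3)))  # expected: close to 0 for every channel
    print(y.std(dim=(0, 2, 3)))   # expected: close to 1 for every channel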

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# This file is part of Synchronized-BatchNorm-PyTorch.

__all__ = (
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 'DataParallelWithCallBack',
    'convert_model', 'patch_replication_callback'
)

from .func import *
from .sync_bn import *
from .parallel import *

--------------------------------------------------------------------------------
/func.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File : func.py
# Author : CW
# Email : chrisway613@gmail.com
# Date : 21/01/2020
#
# This file is part of Synchronized-BatchNorm-PyTorch.

__all__ = ('patch_replication_callback', 'convert_model')

# +
try:
    from sync_bn import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
    from parallel import DataParallelWithCallBack
except ImportError:
    from .sync_bn import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
    from .parallel import DataParallelWithCallBack

from torch.nn import DataParallel, BatchNorm1d, BatchNorm2d, BatchNorm3d
# -

import torch
import functools


def convert_model(module):
    """
    Convert the input module and all of its children recursively,
    replacing every `torch.nn.BatchNormNd` with the corresponding `SynchronizedBatchNormNd`.
    :param module: the input module to be converted to use SyncBN;
    :return: the converted module.
    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after conversion, m uses SyncBN
        >>> m = convert_model(m)
    """

    def _convert(mod_old):
        if 'BatchNorm' not in type(mod_old).__name__:
            return mod_old

        mod_new = mod_old
        for pth_module, sync_module in zip(
            [BatchNorm1d,
             BatchNorm2d,
             BatchNorm3d],
            [SynchronizedBatchNorm1d,
             SynchronizedBatchNorm2d,
             SynchronizedBatchNorm3d]
        ):
            if isinstance(mod_old, pth_module):
                mod_new = sync_module(mod_old.num_features, mod_old.eps, mod_old.momentum, mod_old.affine)
                mod_new.running_mean = mod_old.running_mean
                mod_new.running_var = mod_old.running_var

                if mod_old.affine:
                    mod_new.weight.data = mod_old.weight.data.clone().detach()
                    mod_new.bias.data = mod_old.bias.data.clone().detach()

        return mod_new

    if isinstance(module, torch.nn.DataParallel):
        # Top model inside DataParallel.
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallBack(mod, device_ids=module.device_ids)

        return mod

    mod_cvt = _convert(module)
    for name, child in module.named_children():
        # Recurse into the children so that BN layers nested more than one level deep are converted as well.
        mod_cvt.add_module(name, convert_model(child))

    return mod_cvt


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.
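    The patched `replicate` keeps its original behaviour and, in addition, invokes
    `DataParallelWithCallBack._callback` on the freshly created replicas.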

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallBack(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)
    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        replicas = old_replicate(module, device_ids)
        # Invoke the replication callback on the freshly created replicas.
        DataParallelWithCallBack._callback(replicas)

        return replicas

    data_parallel.replicate = new_replicate

--------------------------------------------------------------------------------
/parallel.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File : parallel.py
# Author : CW
# Email : chrisway613@gmail.com
# Date : 21/01/2020
#
# This file is part of Synchronized-BatchNorm-PyTorch.

__all__ = ('DataParallelWithCallBack',)

from torch.nn.parallel.data_parallel import DataParallel


class DataParallelContext:
    """
    Context data structure for data parallel.
    Multiple copies of a module on different devices share the same context;
    through this context, the different copies can share information.
    """
    def __init__(self):
        self.sync_master = None


class DataParallelWithCallBack(DataParallel):
    """
    Data Parallel with a replication callback.

    The replication callback `_sync_replicas` of each synchronized BN module is invoked after the replicas are
    created by the original `replicate` function.
    The callback is invoked with the arguments `_sync_replicas(ctx, copy_id)`.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallBack(sync_bn, device_ids=[0, 1])
        # sync_bn._sync_replicas will be invoked.
    """
    @classmethod
    def _callback(cls, replicas):
        master_copy = replicas[0]
        replicas_ctx = [DataParallelContext() for _ in master_copy.modules()]

        for copy_id, module_replicated in enumerate(replicas):
            for idx, m in enumerate(module_replicated.modules()):
                if 'SynchronizedBatchNorm' in type(m).__name__ and hasattr(m, '_sync_replicas'):
                    m._sync_replicas(replicas_ctx[idx], copy_id)

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        """
        Initialization.
        :param module: module to be parallelized;
        :param device_ids: CUDA devices (default: all devices);
        :param output_device: device location of the output (default: device_ids[0]);
        :param dim: dim of the input data to be scattered & gathered.
        """
        super(DataParallelWithCallBack, self).__init__(
            module, device_ids, output_device, dim
        )

    def replicate(self, module, device_ids):
        """
        Replication with callback.
        :param module: (nn.Module) module to be parallelized;
        :param device_ids: (list of int or torch.device) CUDA devices (default: all devices);
        :return: module replicated on each device.
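        Note: the callback is what wires each `SynchronizedBatchNorm` replica to the shared `SyncMaster`
        context before the parallel forward pass runs.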
        """
        replicas = super(DataParallelWithCallBack, self).replicate(module, device_ids)
        self._callback(replicas)

        return replicas

    def forward(self, *inputs, **kwargs):
        """
        Note that this method will invoke the following methods (in order):
            i). self.scatter;
            ii). self.replicate;
            iii). self.parallel_apply;
            iv). self.gather
        """
        return super(DataParallelWithCallBack, self).forward(*inputs, **kwargs)

--------------------------------------------------------------------------------
/sync.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File : sync.py
# Author : CW
# Email : chrisway613@gmail.com
# Date : 21/01/2020
#
# This file is part of Synchronized-BatchNorm-PyTorch.

__all__ = ('FutureResult', 'SlavePipe', 'SyncMaster')

try:
    from utils import *
except ImportError:
    from .utils import *

import time
import queue
import collections


_Registry = collections.namedtuple('_Registry', ('result',))
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ('identifier', 'queue', 'result'))


class SlavePipe(_SlavePipeBase):
    """Pipe for master <=> slave communication."""
    def run_slave(self, msg):
        # Put the msg into the queue shared with the master & all other slave copies.
        self.queue.put((self.identifier, msg))
        # Get the result from the master.
        ret = self.result.get()
        # Notify the master that the result has been received.
        self.queue.put(True)

        return ret


class SyncMaster:
    """An abstract `SyncMaster` object.

    - During replication, as data parallel triggers a callback on each module, every slave device should
      call `register_slave(id)` to obtain a `SlavePipe` used to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
      are collected and passed to a registered callback.
    - After receiving the messages, the master device gathers the information and determines the message
      to be passed back to each slave device.
    """
    def __init__(self, callback=None, sync_timeout=15.):
        """
        Args:
            callback: a callback method to be invoked after the messages from the slave devices have been collected.
            sync_timeout: the maximum time (in seconds) the master waits for the slaves' messages.
        """
        self._callback = callback
        self._sync_timeout = sync_timeout

        self._activated = False
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()

    @property
    def num_slaves(self):
        return len(self._registry)

    def register_slave(self, identifier):
        """
        Register a slave device.
        The 'future' data structure stores the slave's result;
        the '_registry' attribute records the mapping between a slave's copy id & its result;
        the master & all of its copies share the same queue.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            # assert self._queue.empty(), 'Queue is not cleaned before next initialization!'
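            # A new replication round is starting: drop any stale messages and registrations left over
            # from the previous forward pass.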
            self._queue.queue.clear()
            self._activated = False
            self._registry.clear()

        future = FutureResult(wait_timeout=2 * self._sync_timeout)
        self._registry[identifier] = _Registry(future)

        return SlavePipe(identifier, self._queue, future)

    def run_master(self, msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Note that if a timeout occurs while waiting for the slaves, this method returns None.

        Args:
            msg: the message that the master wants to send to itself. This will be placed as the first
              message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, msg)]
        prev_time = time.time()
        # Wait until all of the slaves' messages have been gathered, or a timeout occurs.
        while self._queue.qsize() != self.num_slaves:
            cur_time = time.time()
            time_used = cur_time - prev_time

            if time_used > self._sync_timeout:
                return None

        intermediates.extend([self._queue.get() for _ in range(self.num_slaves)])
        # print("intermediates: ", intermediates)
        results = self._callback(intermediates)
        # print(results)
        assert results[0][0] == 0, 'The first result should belong to the master!'

        # results[0] belongs to the master.
        for i, res in results[1:]:
            # Return the result to the corresponding slave.
            self._registry[i].result.put(res)

        # Check that every slave has received its result.
        for i in range(self.num_slaves):
            assert self._queue.get() is True

        # Return the part of the result that belongs to the master itself.
        return results[0][1]

--------------------------------------------------------------------------------
/sync_bn.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File : sync_bn.py
# Author : CW
# Email : chrisway613@gmail.com
# Date : 21/01/2020
#
# This file is part of Synchronized-BatchNorm-PyTorch.

__all__ = (
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'
)

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None
from torch.nn.modules.batchnorm import _BatchNorm

try:
    from utils import *
    from sync import SyncMaster
    from parallel import DataParallelWithCallBack
except ImportError:
    from .utils import *
    from .sync import SyncMaster
    from .parallel import DataParallelWithCallBack

import collections

import torch
import torch.nn.functional as F


_MessageToCollect = collections.namedtuple('_ChildMessage', ('sum', 'ssum', 'sum_size'))
_MessageToBroadcast = collections.namedtuple('_MasterMessage', ('mean', 'inv_std'))


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, sync_timeout=15.):
        assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
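        # ReduceAddCoalesced / Broadcast come from torch.nn.parallel._functions (see the import at the top
        # of this file); they are None when that import failed, which is treated here as missing CUDA support.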

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._is_parallel = False
        self._parallel_id = None

        self._sync_master = SyncMaster(callback=self._coalesce_and_compute, sync_timeout=sync_timeout)
        self._slave_pipe = None

    @property
    def _is_master(self):
        assert self._parallel_id is not None, "The parallel replicate method should be executed first!"
        return self._parallel_id == 0

    def forward(self, inputs):
        # If this is not a parallel computation, or we are in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                inputs, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps
            )

        inputs_shape = inputs.shape
        # Reshape to (N, C, -1), where N is the batch size and C is the number of features/channels.
        inputs = inputs.reshape(inputs_shape[0], self.num_features, -1)
        # Compute the sum and square-sum.
        sum_size = inputs.size(0) * inputs.size(2)
        input_sum = sum_ft(inputs)
        input_ssum = sum_ft(inputs ** 2)
        # The master will collect this message from all copies.
        msg = _MessageToCollect(input_sum, input_ssum, sum_size)
        # Reduce & broadcast the statistics.
        if self._is_master:
            # print("run master\n")
            result = self._sync_master.run_master(msg)

            # When a timeout occurred while synchronizing with the slaves,
            # the result will be None;
            # in that case fall back to PyTorch's implementation
            # (and reshape back to the original input shape).
            if result is None:
                return F.batch_norm(
                    inputs, self.running_mean, self.running_var, self.weight, self.bias,
                    self.training, self.momentum, self.eps
                ).reshape(inputs_shape)
            else:
                mean, inv_std = result
        else:
            # print("run slave\n")
            result_from_master = self._slave_pipe.run_slave(msg)

            # When a timeout occurred while synchronizing with the master,
            # the result from the master will be None;
            # in that case fall back to PyTorch's implementation
            # (and reshape back to the original input shape).
            if result_from_master is None:
                return F.batch_norm(
                    inputs, self.running_mean, self.running_var, self.weight, self.bias,
                    self.training, self.momentum, self.eps
                ).reshape(inputs_shape)
            else:
                mean, inv_std = result_from_master

        # Compute the output.
        if self.affine:
            outputs = (inputs - unsqueeze_ft(mean)) * unsqueeze_ft(inv_std * self.weight) + unsqueeze_ft(self.bias)
        else:
            outputs = (inputs - unsqueeze_ft(mean)) * unsqueeze_ft(inv_std)

        # Reshape back to the original input shape.
        return outputs.reshape(inputs_shape)

    def _sync_replicas(self, ctx, copy_id):
        """
        Synchronize all copies of a module.
        :param ctx: a context data structure for communication;
        :param copy_id: id of a copied module (usually the device id).
        :return:
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means the master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _coalesce_and_compute(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast them."""

        # Ensure that the master is the first one.
        intermediates = sorted(intermediates, key=lambda i: i[0])

        # Get the sum & square-sum from every device.
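        # Each element of `intermediates` is a pair (copy_id, _MessageToCollect(sum, ssum, sum_size)),
        # so i[1][:2] picks out the sum and square-sum tensors.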
        to_reduce = [i[1][:2] for i in intermediates]
        # Flatten: [sum_0, ssum_0, sum_1, ssum_1, ...]
        to_reduce = [j for i in to_reduce for j in i]
        # Total number of elements accumulated over all devices.
        sum_size = sum([i[1].sum_size for i in intermediates])
        # The device of every copy.
        target_gpus = [i[1].sum.get_device() for i in intermediates]
        # print("target gpus: ", target_gpus)

        # Reduce the sums & square-sums from all copies
        # and place the result on the master device.
        # The '2' means that there are 2 tensors (sum and square-sum) per copy.
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
        # Broadcast the computed statistics back to every device.
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
        # print("broadcasted: ", broadcasted)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MessageToBroadcast(*broadcasted[i*2:i*2+2])))

        # print("outputs: ", outputs)
        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """
        Compute the mean and standard-deviation from the sum and square-sum. This method
        also maintains the moving average on the master device.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1!'

        def _compute():
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        mean = sum_ / size
        sum_var = ssum - sum_ * mean
        unbias_var = sum_var / (size - 1)
        bias_var = sum_var / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                _compute()
        else:
            _compute()

        # Return the mean and the inverse standard deviation (computed from the biased variance).
        return mean, bias_var.clamp(self.eps) ** -.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> inputs = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(inputs)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                'expected 2D or 3D input (got {}D input)'.format(input.dim())
            )


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters.
            Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> inputs = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> outputs = m(inputs)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(
                'expected 4D input (got {}D input)'.format(input.dim())
            )


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters.
            Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> inputs = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(inputs)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(
                'expected 5D input (got {}D input)'.format(input.dim())
            )

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
# File: test.py

# This file is used for testing.

from func import *
from sync_bn import *
from parallel import *

from torch import nn

import os
import torch

DEV_IDS = [1, 5]
DEV = torch.device('cuda:{}'.format(DEV_IDS[0]) if torch.cuda.is_available() else 'cpu')
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(dev_id) for dev_id in DEV_IDS])


class ModelBn(nn.Module):
    def __init__(self, num_features=3):
        super(ModelBn, self).__init__()
        self.bn = nn.BatchNorm2d(num_features=num_features)

    def forward(self, inputs):
        outputs = self.bn(inputs)
        return outputs


class ModelSynBn(nn.Module):
    def __init__(self, num_features=3):
        super(ModelSynBn, self).__init__()
        self.bn = SynchronizedBatchNorm2d(num_features=num_features)

    def forward(self, inputs):
        outputs = self.bn(inputs)
        return outputs


if __name__ == '__main__':
    model_syn_bn = ModelSynBn()
    model_syn_bn = DataParallelWithCallBack(model_syn_bn, device_ids=DEV_IDS)
    model_syn_bn.to(DEV)
    print(model_syn_bn)

    x = torch.randint(low=0, high=256, size=(4, 3, 256, 256), device=DEV).float()
    print(x)

    y = model_syn_bn(x)
    print(y)

    # *mean*
    print(y.mean(dim=(0, 2, 3)))
    # *std*
    print(y.std(dim=(0, 2, 3)))

    model_bn = ModelBn()
    model_bn = nn.DataParallel(model_bn, device_ids=DEV_IDS)
    model_bn.to(DEV)
    print(model_bn)

    y = model_bn(x)
    print(y)

    # *mean*
    print(y.mean(dim=(0, 2, 3)))
    # *std*
    print(y.std(dim=(0, 2, 3)))

    # *Use 'convert_model' to convert the input module and its children recursively*
    model_cvt = convert_model(model_bn)
    print(model_cvt)

    y = model_cvt(x)
    print(y)

    # *mean*
    print(y.mean(dim=(0, 2, 3)))
    # *std*
    print(y.std(dim=(0, 2, 3)))

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File : utils.py
# Author : CW
# Email : chrisway613@gmail.com
# Date : 21/01/2020
#
# This file is part of Synchronized-BatchNorm-PyTorch.

__all__ = ('FutureResult', 'sum_ft', 'unsqueeze_ft')

import threading


class FutureResult:
    """A thread-safe future implementation.
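    One producer (the master) puts a single result and one consumer (a slave) takes it,
    blocking for at most `wait_timeout` seconds.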
    Used only as a one-to-one pipe."""

    def __init__(self, wait_timeout=30.):
        self._wait_timeout = wait_timeout

        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result has not been fetched!'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait(timeout=self._wait_timeout)

            res = self._result
            self._result = None

            return res


def sum_ft(tensor):
    """Sum over the first and the last dimension."""
    return tensor.sum(dim=0).sum(dim=-1)


def unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)

--------------------------------------------------------------------------------