├── .gitignore ├── SyncBN ├── __init__.py ├── utils.py ├── syncbn.py └── functions.py ├── examples ├── utils │ ├── __init__.py │ ├── scheduler.py │ ├── utils.py │ ├── visualizer.py │ └── stream_metrics.py ├── resnet.py └── cifar.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /SyncBN/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import SyncBatchNorm1d, SyncBatchNorm2d, SyncBatchNorm3d, SyncBatchNorm -------------------------------------------------------------------------------- /examples/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .visualizer import Visualizer 3 | from .scheduler import PolyLR 4 | from .stream_metrics import StreamClsMetrics 5 | -------------------------------------------------------------------------------- /examples/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler, StepLR 2 | 3 | class PolyLR(_LRScheduler): 4 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1): 5 | self.power = power 6 | self.max_iters = max_iters 7 | super(PolyLR, self).__init__(optimizer, last_epoch) 8 | 9 | def get_lr(self): 10 | return [ base_lr * ( 1 - self.last_epoch/self.max_iters )**self.power 11 | for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /SyncBN/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .syncbn import SyncBatchNorm 3 | 4 | def convert_sync_batchnorm(module): 5 | module_output = module 6 | if isinstance(module, nn.modules.batchnorm._BatchNorm): 7 | module_output = SyncBatchNorm(module.num_features, 8 | module.eps, module.momentum, 9 | module.affine, 10 | module.track_running_stats) 11 | if module.affine: 12 | module_output.weight.data = module.weight.data.clone().detach() 13 | module_output.bias.data = module.bias.data.clone().detach() 14 | module_output.running_mean = module.running_mean 15 | module_output.running_var = module.running_var 16 | module_output.num_batches_tracked = module.num_batches_tracked 17 | for name, child in module.named_children(): 18 | module_output.add_module(name, convert_sync_batchnorm(child)) 19 | return module_output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SyncBatchNorm 2 | 3 | Pytorch synchronized batch normalization implemented in pure python. 4 | 5 | This repo is inspired by [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding). 6 | 7 | 8 | # Requirements 9 | 10 | pytorch >= 1.0 11 | 12 | # Quick Start 13 | 14 | It is recommended to convert your model to sync version using convert_sync_batchnorm. 15 | 16 | ```python 17 | from SyncBN import SyncBatchNorm2d 18 | from SyncBN.utils import convert_sync_batchnorm 19 | from torchvision.models import resnet34 20 | 21 | sync_model = convert_sync_batchnorm( resnet34() ) # build resnet34 and replace bn with syncbn 22 | sync_model = torch.nn.DataParallel(sync_model) # Parallel on multi gpus 23 | ``` 24 | 25 | Or you can build your model from scratch. 26 | 27 | ```python 28 | from SyncBN import SyncBatchNorm2d 29 | 30 | sync_model = nn.Sequential( 31 | nn.Conv2d(3, 12, 3, 1, 1), 32 | SyncBatchNorm2d(12, momentum=0.1, eps=1e-5, affine=True), 33 | nn.ReLU(), 34 | ) 35 | sync_model = torch.nn.DataParallel(sync_model) # Parallel on multi gpus 36 | ``` 37 | # Cifar example 38 | 39 | ```bash 40 | cd SyncBatchNorm/examples 41 | python cifar.py --gpu_id 0,1 --data_root ./data --batch_size 64 --sync_bn 42 | ``` 43 | 44 | -------------------------------------------------------------------------------- /examples/utils/utils.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms.functional import normalize 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | 6 | def denormalize(tensor, mean, std): 7 | mean = np.array(mean) 8 | std = np.array(std) 9 | 10 | _mean = -mean/std 11 | _std = 1/std 12 | return normalize(tensor, _mean, _std) 13 | 14 | class Denormalize(object): 15 | def __init__(self, mean, std): 16 | mean = np.array(mean) 17 | std = np.array(std) 18 | self._mean = -mean/std 19 | self._std = 1/std 20 | 21 | def __call__(self, tensor): 22 | if isinstance(tensor, np.ndarray): 23 | return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1) 24 | return normalize(tensor, self._mean, self._std) 25 | 26 | def fix_bn(model): 27 | for m in model.modules(): 28 | if isinstance(m, nn.BatchNorm2d): 29 | m.eval() 30 | m.weight.requires_grad = False 31 | m.bias.requires_grad = False 32 | 33 | def mkdir(path): 34 | if not os.path.exists(path): 35 | os.mkdir(path) 36 | 37 | def convert_bn2gn(module): 38 | mod = module 39 | if isinstance(module, nn.modules.batchnorm._BatchNorm): 40 | num_features = module.num_features 41 | num_groups = num_features//16 42 | mod = nn.GroupNorm(num_groups=num_groups, num_channels=num_features) 43 | for name, child in module.named_children(): 44 | mod.add_module(name, convert_bn2gn(child)) 45 | del module 46 | return mod 47 | 48 | def group_params(net): 49 | group_decay = [] 50 | group_no_decay = [] 51 | for m in net.modules(): 52 | if isinstance(m, nn.modules.batchnorm._BatchNorm):# or isinstance(m, nn.Linear): 53 | group_no_decay.extend(m.parameters(recurse=False)) 54 | else: 55 | for name, params in m.named_parameters(recurse=False): 56 | if 'bias' in name: 57 | group_no_decay.append(params) 58 | else: 59 | group_decay.append(params) 60 | return group_decay, group_no_decay 61 | -------------------------------------------------------------------------------- /examples/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | from visdom import Visdom 2 | import json 3 | 4 | class Visualizer(object): 5 | """ Visualizer 6 | """ 7 | def __init__(self, port='13579', env='main', id=None): 8 | self.cur_win = {} 9 | self.vis = Visdom(port=port, env=env) 10 | self.id = id 11 | self.env = env 12 | # Restore 13 | ori_win = self.vis.get_window_data() 14 | ori_win = json.loads(ori_win) 15 | self.cur_win = { v['title']: k for k, v in ori_win.items() } 16 | 17 | def vis_scalar(self, win_name, trace_name, x, y, opts=None): 18 | """ Draw line 19 | """ 20 | if not isinstance(x, list): 21 | x = [x] 22 | if not isinstance(y, list): 23 | y = [y] 24 | 25 | if self.id is not None: 26 | win_name = "[%s]"%self.id + win_name 27 | 28 | default_opts = { 'title': win_name } 29 | 30 | if opts is not None: 31 | default_opts.update(opts) 32 | 33 | win = self.cur_win.get(win_name, None) 34 | 35 | if win is not None: 36 | self.vis.line( X=x, Y=y, opts=default_opts, update='append',win=win, name=trace_name ) 37 | else: 38 | self.cur_win[win_name] = self.vis.line( X=x, Y=y, opts=default_opts, name=trace_name) 39 | 40 | def vis_image(self, name, img, env=None, opts=None): 41 | """ vis image in visdom 42 | """ 43 | if env is None: 44 | env = self.env 45 | if self.id is not None: 46 | name = "[%s]"%self.id + name 47 | win = self.cur_win.get(name, None) 48 | default_opts = { 'title': name } 49 | if opts is not None: 50 | default_opts.update(opts) 51 | if win is not None: 52 | self.vis.image( img=img, win=win, opts=opts, env=env ) 53 | else: 54 | self.cur_win[name] = self.vis.image( img=img, opts=default_opts, env=env ) 55 | 56 | def vis_table(self, name, tbl, opts=None): 57 | win = self.cur_win.get(name, None) 58 | 59 | tbl_str = " " 60 | tbl_str+=" \ 61 | \ 62 | \ 63 | " 64 | for k, v in tbl.items(): 65 | tbl_str+= " \ 66 | \ 67 | \ 68 | "%(k, v) 69 | 70 | tbl_str+="
TermValue
%s%s
" 71 | 72 | default_opts = { 'title': name } 73 | if opts is not None: 74 | default_opts.update(opts) 75 | if win is not None: 76 | self.vis.text(tbl_str, win=win, opts=default_opts) 77 | else: 78 | self.cur_win[name] = self.vis.text(tbl_str, opts=default_opts) 79 | 80 | 81 | if __name__=='__main__': 82 | import numpy as np 83 | vis = Visualizer(port=13500, env='main') 84 | tbl = {"lr": 214, "momentum": 0.9} 85 | vis.vis_table("test_table", tbl) 86 | tbl = {"lr": 244444, "momentum": 0.9, "haha": "hoho"} 87 | vis.vis_table("test_table", tbl) 88 | -------------------------------------------------------------------------------- /SyncBN/syncbn.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.modules.batchnorm import _BatchNorm 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | from queue import Queue 7 | 8 | from .functions import SyncBNFunction 9 | 10 | import collections 11 | 12 | _bn_context = collections.namedtuple("_bn_context", ['sync', 'is_master', 'cur_device', 'queue', 'devices']) 13 | 14 | class SyncBatchNorm(_BatchNorm): 15 | """ Sync BN 16 | """ 17 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): 18 | super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) 19 | 20 | self.devices = list(range(torch.cuda.device_count())) 21 | self.sync = len(self.devices)>1 22 | self._slaves = self.devices[1:] 23 | self._queues = [ Queue(len(self._slaves)) ] + [ Queue(1) for _ in self._slaves ] 24 | 25 | def _check_input_dim(self, input): 26 | if input.dim() < 2: 27 | raise ValueError('expected at least 2 dims (got {}D input)' 28 | .format(input.dim())) 29 | 30 | def forward(self, input): 31 | self._check_input_dim(input) 32 | 33 | if not self.training and self.track_running_stats: 34 | return F.batch_norm(input, running_mean=self.running_mean, running_var=self.running_var, 35 | weight=self.weight, bias=self.bias, training=False, momentum=0.0, eps=self.eps) 36 | else: 37 | exponential_average_factor = 0.0 38 | if self.num_batches_tracked is not None: # track running statistics 39 | self.num_batches_tracked += 1 40 | if self.momentum is None: # use cumulative moving average 41 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 42 | else: # use exponential moving average 43 | exponential_average_factor = self.momentum 44 | 45 | if input.is_cuda: 46 | cur_device = input.get_device() 47 | bn_ctx = _bn_context( self.sync, (cur_device==self.devices[0]), cur_device, self._queues, self.devices ) 48 | else: 49 | bn_ctx = _bn_context( False, True, None, None, None ) 50 | 51 | return SyncBNFunction.apply( input, self.weight, self.bias, self.running_mean, self.running_var, exponential_average_factor, self.eps, self.training, bn_ctx ) 52 | 53 | 54 | class SyncBatchNorm1d(SyncBatchNorm): 55 | def _check_input_dim(self, input): 56 | if input.dim() != 2 and input.dim() != 3: 57 | raise ValueError('expected 2D or 3D input (got {}D input)' 58 | .format(input.dim())) 59 | 60 | 61 | class SyncBatchNorm2d(SyncBatchNorm): 62 | def _check_input_dim(self, input): 63 | if input.dim() != 4: 64 | raise ValueError('expected 4D input (got {}D input)' 65 | .format(input.dim())) 66 | 67 | 68 | class SyncBatchNorm3d(SyncBatchNorm): 69 | def _check_input_dim(self, input): 70 | if input.dim() != 5: 71 | raise ValueError('expected 5D input (got {}D input)' 72 | .format(input.dim())) 73 | -------------------------------------------------------------------------------- /examples/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion*planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 53 | nn.BatchNorm2d(self.expansion*planes) 54 | ) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x))) 58 | out = F.relu(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out += self.shortcut(x) 61 | out = F.relu(out) 62 | return out 63 | 64 | 65 | class ResNet(nn.Module): 66 | def __init__(self, block, num_blocks, num_classes=10): 67 | super(ResNet, self).__init__() 68 | self.in_planes = 64 69 | 70 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(64) 72 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 73 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 74 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 75 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 76 | self.linear = nn.Linear(512*block.expansion, num_classes) 77 | 78 | def _make_layer(self, block, planes, num_blocks, stride): 79 | strides = [stride] + [1]*(num_blocks-1) 80 | layers = [] 81 | for stride in strides: 82 | layers.append(block(self.in_planes, planes, stride)) 83 | self.in_planes = planes * block.expansion 84 | return nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | out = F.relu(self.bn1(self.conv1(x))) 88 | out = self.layer1(out) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = self.layer4(out) 92 | out = F.avg_pool2d(out, 4) 93 | out = out.view(out.size(0), -1) 94 | out = self.linear(out) 95 | return out 96 | 97 | 98 | def resnet18(): 99 | return ResNet(BasicBlock, [2,2,2,2]) 100 | 101 | def resnet34(): 102 | return ResNet(BasicBlock, [3,4,6,3]) 103 | 104 | def resnet50(): 105 | return ResNet(Bottleneck, [3,4,6,3]) 106 | 107 | def resnet101(): 108 | return ResNet(Bottleneck, [3,4,23,3]) 109 | 110 | def resnet152(): 111 | return ResNet(Bottleneck, [3,8,36,3]) 112 | 113 | -------------------------------------------------------------------------------- /examples/utils/stream_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class _StreamMetrics(object): 5 | def __init__(self): 6 | """ Overridden by subclasses """ 7 | raise NotImplementedError() 8 | 9 | def update(self, gt, pred): 10 | """ Overridden by subclasses """ 11 | raise NotImplementedError() 12 | 13 | def get_results(self): 14 | """ Overridden by subclasses """ 15 | raise NotImplementedError() 16 | 17 | def to_str(self, metrics): 18 | """ Overridden by subclasses """ 19 | raise NotImplementedError() 20 | 21 | def reset(self): 22 | """ Overridden by subclasses """ 23 | raise NotImplementedError() 24 | 25 | class StreamClsMetrics(_StreamMetrics): 26 | def __init__(self, n_classes): 27 | self.n_classes = n_classes 28 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 29 | 30 | def update(self, label_trues, label_preds): 31 | for lt, lp in zip(label_trues, label_preds): 32 | self.confusion_matrix[lp][lt] += 1 33 | 34 | @staticmethod 35 | def to_str(results): 36 | string = "\n" 37 | for k, v in results.items(): 38 | if k!="Class IoU": 39 | string += "%s: %f\n"%(k, v) 40 | 41 | string+='Class IoU:\n' 42 | for k, v in results['Class IoU'].items(): 43 | string += "\tclass %d: %f\n"%(k, v) 44 | return string 45 | 46 | def get_results(self): 47 | """Returns accuracy score evaluation result. 48 | - overall accuracy 49 | - mean accuracy 50 | - mean IU 51 | - fwavacc 52 | """ 53 | hist = self.confusion_matrix 54 | acc = np.diag(hist).sum() / hist.sum() 55 | acc_cls = np.diag(hist) / hist.sum(axis=1) 56 | acc_cls = np.nanmean(acc_cls) 57 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 58 | mean_iu = np.nanmean(iu) 59 | freq = hist.sum(axis=1) / hist.sum() 60 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 61 | cls_iu = dict(zip(range(self.n_classes), iu)) 62 | 63 | return { 64 | "Overall Acc": acc, 65 | "Mean Acc": acc_cls, 66 | "FreqW Acc": fwavacc, 67 | "Mean IoU": mean_iu, 68 | "Class IoU": cls_iu 69 | } 70 | 71 | def reset(self): 72 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 73 | 74 | 75 | class AverageMeter(object): 76 | """Computes and stores the average and current value""" 77 | def __init__(self): 78 | self.book = dict() 79 | 80 | def reset_all(self): 81 | self.book.clear() 82 | 83 | def reset(self, id): 84 | item = self.book.get(id, None) 85 | if item is not None: 86 | item[0] = 0 87 | item[1] = 0 88 | 89 | def update(self, id, val): 90 | record = self.book.get(id, None) 91 | if record is None: 92 | self.book[id] = [val, 1] 93 | else: 94 | record[0]+=val 95 | record[1]+=1 96 | 97 | def get_results(self, id): 98 | record = self.book.get(id, None) 99 | assert record is not None 100 | return record[0] / record[1] 101 | 102 | class VocClsMetrics(_StreamMetrics): 103 | def __init__(self, n_classes): 104 | self.n_classes = n_classes 105 | self.confusion_matrix = np.zeros(shape=(n_classes, 4)) # n_classes, TN, FP, FN, TP 106 | #00, 01, 10, 11 target-predict 107 | 108 | def update(self, label_trues, label_preds): 109 | for lt, lp in zip(label_trues, label_preds): 110 | idx = (tuple(range(self.n_classes)), lt*2+lp) 111 | self.confusion_matrix[idx] += 1 112 | 113 | def to_str(self, results): 114 | string = "\n" 115 | string += "Overall Acc: %f\n"%(results['Overall Acc']) 116 | string += "Overall Precision: %f\n"%(results['Overall Precision']) 117 | string += "Overall Recall: %f\n"%(results['Overall Recall']) 118 | string += "Overall F1: %f\n"%(results['Overall F1']) 119 | 120 | string+='Class Metrics:\n' 121 | 122 | for i in range(self.n_classes): 123 | string += "\tclass %d: acc=%f, precision=%f, recall=%f, f1=%f\n"%(i,results['Class Acc'][i],results['Class Precision'][i],results['Class Recall'][i],results['Class F1'][i] ) 124 | return string 125 | 126 | def get_results(self): 127 | TN = self.confusion_matrix[:, 0] 128 | FP = self.confusion_matrix[:, 1] 129 | FN = self.confusion_matrix[:, 2] 130 | TP = self.confusion_matrix[:, 3] 131 | 132 | class_accuracy = np.nan_to_num( ( TN+TP ) / (TN+FP+FN+TP) ) 133 | class_precision = np.nan_to_num( TP / ( TP+FP ) ) 134 | class_recall = np.nan_to_num( TP / ( TP+FN ) ) 135 | class_f1 = np.nan_to_num( 2* (class_precision * class_recall) / (class_precision+class_recall) ) 136 | 137 | return {'Overall Acc': class_accuracy.mean(), 138 | 'Overall Precision': class_precision.mean(), 139 | 'Overall Recall': class_recall.mean(), 140 | 'Overall F1': class_f1.mean(), 141 | 'Class Acc': class_accuracy, 142 | 'Class Precision': class_precision, 143 | 'Class Recall': class_recall, 144 | 'Class F1': class_f1} 145 | 146 | def reset(self): 147 | self.correct = np.zeros(self.n_classes) 148 | -------------------------------------------------------------------------------- /SyncBN/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.autograd.function import once_differentiable 4 | import torch.nn.functional as F 5 | 6 | def unsqueeze(tensor): 7 | return tensor.unsqueeze(1).unsqueeze(0) 8 | 9 | class SyncBNFunction(Function): 10 | 11 | @staticmethod 12 | def forward(ctx, x, weight, bias, running_mean, running_var, momentum, eps, training, bn_ctx): 13 | x_shape = x.shape 14 | B, C = x_shape[:2] 15 | 16 | _x = x.view(B,C,-1).contiguous() 17 | 18 | ctx.eps = eps 19 | ctx.training = training 20 | 21 | ctx.sync = bn_ctx.sync 22 | ctx.cur_device = bn_ctx.cur_device 23 | ctx.queue = bn_ctx.queue 24 | ctx.is_master = bn_ctx.is_master 25 | ctx.devices = bn_ctx.devices 26 | 27 | norm = 1/(_x.shape[0] * _x.shape[2]) 28 | 29 | if ctx.training: 30 | _ex = _x.sum(2).sum(0) * norm 31 | _exs = _x.pow(2).sum(2).sum(0) * norm 32 | 33 | if ctx.sync: 34 | if ctx.is_master: 35 | 36 | _ex, _exs = [_ex.unsqueeze(1)], [_exs.unsqueeze(1)] 37 | 38 | master_queue = ctx.queue[0] 39 | for j in range(master_queue.maxsize): 40 | _slave_ex, _slave_exs = master_queue.get() 41 | master_queue.task_done() 42 | 43 | _ex.append( _slave_ex.unsqueeze(1) ) 44 | _exs.append( _slave_exs.unsqueeze(1) ) 45 | 46 | _ex = torch.cuda.comm.gather( _ex, dim=1 ).mean(1) 47 | _exs = torch.cuda.comm.gather( _exs, dim=1 ).mean(1) 48 | 49 | distributed_tensor = torch.cuda.comm.broadcast_coalesced( (_ex, _exs), ctx.devices ) 50 | 51 | for dt, q in zip( distributed_tensor[1:], ctx.queue[1:] ): 52 | q.put(dt) 53 | else: 54 | master_queue = ctx.queue[0] 55 | slave_queue = ctx.queue[ctx.cur_device] 56 | master_queue.put( (_ex, _exs) ) 57 | 58 | _ex, _exs = slave_queue.get() 59 | slave_queue.task_done() 60 | _ex, _exs = _ex.squeeze(), _exs.squeeze() 61 | 62 | _var = _exs - _ex.pow(2) 63 | N = B*len(ctx.devices) 64 | unbiased_var = _var * N / (N - 1) 65 | 66 | running_mean.mul_( (1-momentum) ).add_( momentum * _ex ) 67 | running_var.mul_( (1-momentum) ).add_( momentum * unbiased_var ) 68 | ctx.mark_dirty(running_mean, running_var) 69 | else: 70 | _ex, _var = running_mean.contiguous(), running_var.contiguous() 71 | _exs = _ex.pow(2) + _var 72 | 73 | invstd = 1/torch.sqrt( _var + eps ) 74 | 75 | if weight is not None: # affine 76 | output = (_x - unsqueeze(_ex) ) * unsqueeze(invstd) * unsqueeze(weight) + unsqueeze(bias) 77 | else: 78 | output = (_x - unsqueeze(_ex) ) * unsqueeze(invstd) 79 | 80 | ctx.save_for_backward(x, _ex, _exs, weight, bias) 81 | return output.view(*x_shape).contiguous().clone() 82 | 83 | @staticmethod 84 | def backward(ctx, grad_output): 85 | x, _ex, _exs, weight, bias = ctx.saved_tensors 86 | grad_x = grad_weight = grad_bias = None 87 | 88 | B,C = grad_output.shape[:2] 89 | grad_output_shape = grad_output.shape 90 | 91 | _var = _exs - _ex.pow(2) 92 | _std = torch.sqrt( _var + ctx.eps) 93 | invstd = 1.0 / _std 94 | 95 | grad_output = grad_output.view(B,C,-1) 96 | x = x.view(B,C,-1) 97 | 98 | norm = 1.0/(x.shape[0] * x.shape[2]) 99 | 100 | dot_p = ( grad_output * ( x - unsqueeze( _ex ) ) ).sum(2).sum(0) 101 | grad_output_sum = grad_output.sum(2).sum(0) 102 | 103 | grad_scale = weight * invstd 104 | 105 | grad_ex = -grad_output_sum * grad_scale + _ex * invstd * invstd * dot_p * grad_scale 106 | grad_exs = -0.5 * grad_scale * invstd * invstd * dot_p 107 | 108 | # Sync 109 | if ctx.training: 110 | if ctx.sync: 111 | if ctx.is_master: 112 | grad_ex, grad_exs = [grad_ex.unsqueeze(1)], [grad_exs.unsqueeze(1)] 113 | master_queue = ctx.queue[0] 114 | for j in range(master_queue.maxsize): 115 | grad_slave_ex, grad_slave_exs = master_queue.get() 116 | master_queue.task_done() 117 | 118 | grad_ex.append( grad_slave_ex.unsqueeze(1) ) 119 | grad_exs.append( grad_slave_exs.unsqueeze(1) ) 120 | 121 | grad_ex = torch.cuda.comm.gather( grad_ex, dim=1 ).mean(1) 122 | grad_exs = torch.cuda.comm.gather( grad_exs, dim=1).mean(1) 123 | 124 | distributed_tensor = torch.cuda.comm.broadcast_coalesced( (grad_ex, grad_exs), ctx.devices ) 125 | for dt, q in zip( distributed_tensor[1:], ctx.queue[1:] ): 126 | q.put(dt) 127 | else: 128 | master_queue = ctx.queue[0] 129 | slave_queue = ctx.queue[ctx.cur_device] 130 | master_queue.put( (grad_ex, grad_exs) ) 131 | 132 | grad_ex, grad_exs = slave_queue.get() 133 | slave_queue.task_done() 134 | grad_ex, grad_exs = grad_ex.squeeze(), grad_exs.squeeze() 135 | 136 | if ctx.needs_input_grad[0]: 137 | grad_x = grad_output * unsqueeze( grad_scale ) + unsqueeze( grad_ex * norm ) + unsqueeze(grad_exs) * 2 * x * norm 138 | 139 | if ctx.needs_input_grad[1]: 140 | grad_weight = dot_p * invstd 141 | 142 | if ctx.needs_input_grad[2]: 143 | grad_bias = grad_output_sum 144 | 145 | return grad_x.view(*grad_output_shape), grad_weight, grad_bias, None, None, None, None, None, None 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /examples/cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from resnet import resnet18, resnet34, resnet50, resnet101, resnet152 5 | from torchvision.datasets import CIFAR10 6 | from torchvision import transforms 7 | import argparse 8 | import os, sys 9 | 10 | import utils 11 | import numpy as np 12 | import random 13 | from utils import Visualizer 14 | 15 | from torch.utils import data 16 | 17 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 18 | from SyncBN import SyncBatchNorm2d 19 | from SyncBN.utils import convert_sync_batchnorm 20 | from tqdm import tqdm 21 | 22 | def get_argparser(): 23 | parser = argparse.ArgumentParser() 24 | 25 | # Datset Options 26 | parser.add_argument("--data_root", type=str, default='./data', 27 | help="path to Dataset") 28 | 29 | # Train Options 30 | parser.add_argument("--model", type=str, default='resnet18', 31 | choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']) 32 | parser.add_argument("--epochs", type=int, default=40, 33 | help="epoch number (default: 40)") 34 | parser.add_argument("--lr", type=float, default=0.1, 35 | help="learning rate (default: 0.1)") 36 | 37 | parser.add_argument("--batch_size", type=int, default=128, 38 | help='batch size (default: 128)') 39 | parser.add_argument("--lr_decay_step", type=int, default=150, 40 | help='batch size (default: 150)') 41 | parser.add_argument("--ckpt", default=None, type=str, 42 | help="path to trained model. Leave it None if you want to retrain your model") 43 | parser.add_argument("--gpu_id", type=str, default='0', 44 | help="GPU ID") 45 | 46 | parser.add_argument("--momentum", type=float, default=0.9, 47 | help='momentum for SGD (default: 0.9)') 48 | parser.add_argument("--weight_decay", type=float, default=5e-4, 49 | help='weight decay (default: 5e-4)') 50 | 51 | parser.add_argument("--num_workers", type=int, default=4, 52 | help='number of workers (default: 4)') 53 | parser.add_argument("--val_on_trainset", action='store_true', default=False , 54 | help="enable validation on train set (default: False)") 55 | parser.add_argument("--random_seed", type=int, default=23333, 56 | help="random seed (default: 23333)") 57 | parser.add_argument("--print_interval", type=int, default=10, 58 | help="print interval of loss (default: 10)") 59 | parser.add_argument("--val_interval", type=int, default=1, 60 | help="epoch interval for eval (default: 1)") 61 | parser.add_argument("--ckpt_interval", type=int, default=1, 62 | help="saving interval (default: 1)") 63 | parser.add_argument("--download", action='store_true', default=False, 64 | help="download datasets") 65 | parser.add_argument("--sync_bn", action='store_true', default=False, 66 | help="sync batchnorm") 67 | 68 | # Visdom options 69 | parser.add_argument("--enable_vis", action='store_true', default=False, 70 | help="use visdom for visualization") 71 | parser.add_argument("--vis_port", type=str, default='15555', 72 | help='port for visdom') 73 | parser.add_argument("--vis_env", type=str, default='main', 74 | help='env for visdom') 75 | parser.add_argument("--trace_name", type=str, default=None) 76 | return parser 77 | 78 | 79 | 80 | def train( cur_epoch, criterion, model, optim, train_loader, device, scheduler=None, print_interval=10, vis=None, trace_name=None): 81 | """Train and return epoch loss""" 82 | 83 | if scheduler is not None: 84 | scheduler.step() 85 | print("Epoch %d, lr = %f"%(cur_epoch, optim.param_groups[0]['lr'])) 86 | epoch_loss = 0.0 87 | interval_loss = 0.0 88 | 89 | for cur_step, (images, labels) in enumerate( train_loader ): 90 | 91 | images = images.to(device, dtype=torch.float32) 92 | labels = labels.to(device, dtype=torch.long) 93 | 94 | # N, C, H, W 95 | optim.zero_grad() 96 | outputs = model(images) 97 | loss = criterion(outputs, labels) 98 | 99 | loss.backward() 100 | optim.step() 101 | 102 | np_loss = loss.detach().cpu().numpy() 103 | epoch_loss+=np_loss 104 | interval_loss+=np_loss 105 | 106 | if (cur_step+1)%print_interval==0: 107 | interval_loss = interval_loss/print_interval 108 | print("Epoch %d, Batch %d/%d, Loss=%f"%(cur_epoch, cur_step+1, len(train_loader), interval_loss)) 109 | if vis is not None: 110 | x = cur_epoch*len(train_loader) + cur_step + 1 111 | vis.vis_scalar('Loss', trace_name, x, interval_loss ) 112 | interval_loss=0.0 113 | return epoch_loss / len(train_loader) 114 | 115 | 116 | def validate( model, loader, device, metrics): 117 | """Do validation and return specified samples""" 118 | metrics.reset() 119 | with torch.no_grad(): 120 | for i, (images, labels) in enumerate( tqdm( loader ) ): 121 | 122 | images = images.to(device, dtype=torch.float32) 123 | labels = labels.to(device, dtype=torch.long) 124 | 125 | outputs = model(images) 126 | preds = outputs.detach().max(dim=1)[1].cpu().numpy() 127 | targets = labels.cpu().numpy() 128 | 129 | metrics.update(targets, preds) 130 | score = metrics.get_results() 131 | return score 132 | 133 | 134 | def main(): 135 | opts = get_argparser().parse_args() 136 | # Set up visualization 137 | vis = Visualizer(port=opts.vis_port, env=opts.vis_env) 138 | if vis is not None: # display options 139 | vis.vis_table( "%s opts"%opts.trace_name, vars(opts) ) 140 | os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id 141 | device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) 142 | print("Device: %s"%device) 143 | 144 | # Set up random seed 145 | torch.manual_seed(opts.random_seed) 146 | torch.cuda.manual_seed(opts.random_seed) 147 | np.random.seed(opts.random_seed) 148 | random.seed(opts.random_seed) 149 | 150 | # Set up dataloader 151 | train_dst = CIFAR10(root='./data', train=True, 152 | transform=transforms.Compose([ 153 | transforms.RandomHorizontalFlip(), 154 | transforms.ToTensor(), 155 | transforms.Normalize( mean=[0.485, 0.456, 0.406], 156 | std=[0.229, 0.224, 0.225] )]), 157 | download=opts.download ) 158 | 159 | val_dst = CIFAR10(root='./data', train=False, 160 | transform=transforms.Compose([ 161 | transforms.ToTensor(), 162 | transforms.Normalize( mean=[0.485, 0.456, 0.406], 163 | std=[0.229, 0.224, 0.225] )]), 164 | download=False ) 165 | 166 | train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=opts.num_workers) 167 | val_loader = data.DataLoader(val_dst, batch_size=opts.batch_size, shuffle=False, num_workers=opts.num_workers) 168 | print("Dataset: CIFAR10, Train set: %d, Val set: %d"%(len(train_dst), len(val_dst))) 169 | 170 | model = {"resnet18": resnet18, 171 | "resnet34": resnet34, 172 | "resnet50": resnet50, 173 | "resnet101": resnet101, 174 | "resnet152": resnet152 }[opts.model]() 175 | 176 | trace_name = opts.trace_name 177 | if opts.sync_bn==True: 178 | print("Use sync batchnorm") 179 | model = convert_sync_batchnorm(model) 180 | print(model) 181 | 182 | if torch.cuda.device_count()>1: # Parallel 183 | print("%d GPU parallel"%(torch.cuda.device_count())) 184 | model = torch.nn.DataParallel(model) 185 | model_ref = model.module # for ckpt 186 | else: 187 | model_ref = model 188 | model = model.to(device) 189 | 190 | # Set up metrics 191 | metrics = utils.StreamClsMetrics(10) 192 | 193 | # Set up optimizer 194 | group_decay, group_no_decay = utils.group_params(model_ref) 195 | assert(len(group_decay)+len(group_no_decay) == len(list(model_ref.parameters()))) 196 | optimizer = torch.optim.SGD([ {'params': group_decay, 'weight_decay': opts.weight_decay}, 197 | {'params': group_no_decay}], 198 | lr=opts.lr, momentum=opts.momentum ) 199 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=0.1) 200 | print("optimizer:\n%s"%(optimizer)) 201 | 202 | utils.mkdir('checkpoints') 203 | 204 | # Restore 205 | best_score = 0.0 206 | cur_epoch = 0 207 | if opts.ckpt is not None and os.path.isfile(opts.ckpt): 208 | checkpoint = torch.load(opts.ckpt) 209 | model_ref.load_state_dict(checkpoint["model_state"]) 210 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 211 | scheduler.load_state_dict(checkpoint["scheduler_state"]) 212 | cur_epoch = checkpoint["epoch"]+1 213 | best_score = checkpoint['best_score'] 214 | print("Model restored from %s"%opts.ckpt) 215 | del checkpoint # free memory 216 | else: 217 | print("[!] Retrain") 218 | 219 | def save_ckpt(path): 220 | """ save current model 221 | """ 222 | state = { 223 | "epoch": cur_epoch, 224 | "model_state": model_ref.state_dict(), 225 | "optimizer_state": optimizer.state_dict(), 226 | "scheduler_state": scheduler.state_dict(), 227 | "best_score": best_score, 228 | } 229 | torch.save(state, path) 230 | print( "Model saved as %s"%path ) 231 | 232 | # Set up criterion 233 | criterion = nn.CrossEntropyLoss(reduction='mean') 234 | #========== Train Loop ==========# 235 | while cur_epoch < opts.epochs: 236 | # ===== Train ===== 237 | model.train() 238 | epoch_loss = train(cur_epoch=cur_epoch, criterion=criterion, model=model, optim=optimizer, train_loader=train_loader, device=device, scheduler=scheduler, vis=vis, trace_name=trace_name) 239 | print("End of Epoch %d/%d, Average Loss=%f"%(cur_epoch, opts.epochs, epoch_loss)) 240 | 241 | if opts.enable_vis: 242 | vis.vis_scalar("Epoch Loss", trace_name, cur_epoch, epoch_loss ) 243 | 244 | # ===== Save Latest Model ===== 245 | if (cur_epoch+1)%opts.ckpt_interval==0: 246 | save_ckpt( 'checkpoints/latest_resnet34_cifar10.pkl' ) 247 | 248 | # ===== Validation ===== 249 | if (cur_epoch+1)%opts.val_interval==0: 250 | print("validate on val set...") 251 | model.eval() 252 | val_score = validate(model=model, loader=val_loader, device=device, metrics=metrics) 253 | print(metrics.to_str(val_score)) 254 | 255 | # ===== Save Best Model ===== 256 | if val_score['Mean IoU']>best_score: # save best model 257 | best_score = val_score['Overall Acc'] 258 | save_ckpt( 'checkpoints/latest_resnet34_cifar10.pkl') 259 | 260 | if vis is not None: # visualize validation score and samples 261 | vis.vis_scalar("[Val] Overall Acc",trace_name, cur_epoch, val_score['Overall Acc'] ) 262 | 263 | if opts.val_on_trainset==True: # validate on train set 264 | print("validate on train set...") 265 | model.eval() 266 | train_score = validate(model=model, loader=train_loader, device=device, metrics=metrics) 267 | print(metrics.to_str(train_score)) 268 | if vis is not None: 269 | vis.vis_scalar("[Train] Overall Acc", trace_name, cur_epoch, train_score['Overall Acc'] ) 270 | 271 | cur_epoch+=1 272 | 273 | if __name__=='__main__': 274 | main() 275 | 276 | 277 | 278 | 279 | 280 | 281 | --------------------------------------------------------------------------------