├── .gitignore
├── SyncBN
│   ├── __init__.py
│   ├── utils.py
│   ├── syncbn.py
│   └── functions.py
├── examples
│   ├── utils
│   │   ├── __init__.py
│   │   ├── scheduler.py
│   │   ├── utils.py
│   │   ├── visualizer.py
│   │   └── stream_metrics.py
│   ├── resnet.py
│   └── cifar.py
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/SyncBN/__init__.py:
--------------------------------------------------------------------------------
1 | from .syncbn import SyncBatchNorm1d, SyncBatchNorm2d, SyncBatchNorm3d, SyncBatchNorm
--------------------------------------------------------------------------------
/examples/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .visualizer import Visualizer
3 | from .scheduler import PolyLR
4 | from .stream_metrics import StreamClsMetrics
5 |
--------------------------------------------------------------------------------
/examples/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler, StepLR
2 |
class PolyLR(_LRScheduler):
    """Polynomial learning-rate decay: lr = base_lr * (1 - t / max_iters) ** power.

    Args:
        optimizer: wrapped torch optimizer.
        max_iters: total number of scheduler steps over which lr decays to 0.
        power: exponent of the polynomial (default 0.9).
        last_epoch: index of the last step, -1 to start fresh.
    """
    def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1):
        self.power = power
        self.max_iters = max_iters
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Clamp the decay base at 0: once last_epoch exceeds max_iters the
        # raw expression goes negative, and a negative base raised to a
        # fractional power yields NaN learning rates.
        decay = max(1.0 - self.last_epoch / self.max_iters, 0.0) ** self.power
        return [base_lr * decay for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
/SyncBN/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .syncbn import SyncBatchNorm
3 |
def convert_sync_batchnorm(module):
    """Recursively replace every BatchNorm layer in `module` with SyncBatchNorm.

    Affine parameters and running statistics are copied over so the converted
    model is numerically identical to the original. Returns the converted
    module (the input module itself when nothing was replaced).
    """
    converted = module
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        converted = SyncBatchNorm(module.num_features,
                                  module.eps, module.momentum,
                                  module.affine,
                                  module.track_running_stats)
        if module.affine:
            converted.weight.data = module.weight.data.clone().detach()
            converted.bias.data = module.bias.data.clone().detach()
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
        converted.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        converted.add_module(name, convert_sync_batchnorm(child))
    return converted
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SyncBatchNorm
2 |
3 | PyTorch synchronized batch normalization implemented in pure Python.
4 |
5 | This repo is inspired by [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding).
6 |
7 |
8 | # Requirements
9 |
10 | pytorch >= 1.0
11 |
12 | # Quick Start
13 |
14 | It is recommended to convert your model to sync version using convert_sync_batchnorm.
15 |
16 | ```python
17 | from SyncBN import SyncBatchNorm2d
18 | from SyncBN.utils import convert_sync_batchnorm
19 | from torchvision.models import resnet34
20 |
21 | sync_model = convert_sync_batchnorm( resnet34() ) # build resnet34 and replace bn with syncbn
22 | sync_model = torch.nn.DataParallel(sync_model) # Parallel on multi gpus
23 | ```
24 |
25 | Or you can build your model from scratch.
26 |
27 | ```python
28 | from SyncBN import SyncBatchNorm2d
29 |
30 | sync_model = nn.Sequential(
31 | nn.Conv2d(3, 12, 3, 1, 1),
32 | SyncBatchNorm2d(12, momentum=0.1, eps=1e-5, affine=True),
33 | nn.ReLU(),
34 | )
35 | sync_model = torch.nn.DataParallel(sync_model) # Parallel on multi gpus
36 | ```
37 | # Cifar example
38 |
39 | ```bash
40 | cd SyncBatchNorm/examples
41 | python cifar.py --gpu_id 0,1 --data_root ./data --batch_size 64 --sync_bn
42 | ```
43 |
44 |
--------------------------------------------------------------------------------
/examples/utils/utils.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms.functional import normalize
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 |
def denormalize(tensor, mean, std):
    """Invert a torchvision Normalize(mean, std) transform on `tensor`.

    Uses the identity: x = (x_norm - (-mean/std)) / (1/std).
    """
    mean, std = np.array(mean), np.array(std)
    return normalize(tensor, -mean / std, 1 / std)
13 |
class Denormalize(object):
    """Callable inverse of a Normalize(mean, std) transform.

    Accepts either a torch tensor (delegated to torchvision's normalize)
    or a CHW numpy array (handled directly with broadcasting).
    """
    def __init__(self, mean, std):
        mean, std = np.array(mean), np.array(std)
        self._mean = -mean / std
        self._std = 1 / std

    def __call__(self, tensor):
        if isinstance(tensor, np.ndarray):
            chw_mean = self._mean.reshape(-1, 1, 1)
            chw_std = self._std.reshape(-1, 1, 1)
            return (tensor - chw_mean) / chw_std
        return normalize(tensor, self._mean, self._std)
25 |
def fix_bn(model):
    """Freeze all BatchNorm2d layers in `model`.

    Each layer is put into eval mode (stops running-stat updates) and its
    affine parameters are excluded from gradient computation.
    """
    frozen = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for bn in frozen:
        bn.eval()
        bn.weight.requires_grad = False
        bn.bias.requires_grad = False
32 |
def mkdir(path):
    """Create directory `path` if it does not already exist.

    Uses os.makedirs(..., exist_ok=True), which avoids the check-then-create
    race of the previous os.path.exists + os.mkdir pair and also creates any
    missing parent directories.
    """
    os.makedirs(path, exist_ok=True)
36 |
def convert_bn2gn(module):
    """Recursively replace every BatchNorm layer in `module` with GroupNorm.

    Group count is num_features // 16, floored at 1: GroupNorm requires
    num_groups >= 1, so narrow layers (num_features < 16) previously crashed
    with num_groups == 0.
    """
    mod = module
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        num_features = module.num_features
        num_groups = max(1, num_features // 16)
        mod = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
    for name, child in module.named_children():
        mod.add_module(name, convert_bn2gn(child))
    del module
    return mod
47 |
def group_params(net):
    """Split `net`'s parameters into (decay, no_decay) groups.

    BatchNorm parameters and all biases go to the no-decay group; every
    other parameter (conv/linear weights, ...) gets weight decay.
    """
    decay, no_decay = [], []
    for m in net.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):  # or nn.Linear
            no_decay.extend(m.parameters(recurse=False))
            continue
        for name, params in m.named_parameters(recurse=False):
            bucket = no_decay if 'bias' in name else decay
            bucket.append(params)
    return decay, no_decay
61 |
--------------------------------------------------------------------------------
/examples/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | from visdom import Visdom
2 | import json
3 |
class Visualizer(object):
    """Thin convenience wrapper around a visdom connection.

    Window handles are cached in ``self.cur_win`` (title -> window id) and
    restored from the server at construction, so repeated runs append to
    existing plots instead of opening duplicates.
    """
    def __init__(self, port='13579', env='main', id=None):
        self.cur_win = {}
        self.vis = Visdom(port=port, env=env)
        self.id = id  # optional prefix to tell runs sharing one env apart
        self.env = env
        # Restore windows already present in this env.
        ori_win = json.loads(self.vis.get_window_data())
        self.cur_win = { v['title']: k for k, v in ori_win.items() }

    def vis_scalar(self, win_name, trace_name, x, y, opts=None):
        """Append point(s) (x, y) to the line plot `win_name`, trace `trace_name`."""
        if not isinstance(x, list):
            x = [x]
        if not isinstance(y, list):
            y = [y]

        if self.id is not None:
            win_name = "[%s]"%self.id + win_name

        default_opts = { 'title': win_name }
        if opts is not None:
            default_opts.update(opts)

        win = self.cur_win.get(win_name, None)
        if win is not None:
            self.vis.line( X=x, Y=y, opts=default_opts, update='append', win=win, name=trace_name )
        else:
            self.cur_win[win_name] = self.vis.line( X=x, Y=y, opts=default_opts, name=trace_name )

    def vis_image(self, name, img, env=None, opts=None):
        """Show (create or update) an image window named `name`."""
        if env is None:
            env = self.env
        if self.id is not None:
            name = "[%s]"%self.id + name
        win = self.cur_win.get(name, None)
        default_opts = { 'title': name }
        if opts is not None:
            default_opts.update(opts)
        if win is not None:
            # BUG FIX: the update branch previously passed the raw `opts`
            # (possibly None), dropping the default title; use the merged
            # default_opts exactly as the create branch below does.
            self.vis.image( img=img, win=win, opts=default_opts, env=env )
        else:
            self.cur_win[name] = self.vis.image( img=img, opts=default_opts, env=env )

    def vis_table(self, name, tbl, opts=None):
        """Render dict `tbl` as a two-column (Term / Value) HTML table."""
        win = self.cur_win.get(name, None)

        # NOTE(review): the original HTML markup of this table was lost when
        # the file was extracted (only the cell text survived); rebuilt here
        # as a plain HTML table, the usual payload for visdom's text pane --
        # confirm rendering matches the original.
        rows = "".join( "<tr><td>%s</td><td>%s</td></tr>"%(k, v)
                        for k, v in tbl.items() )
        tbl_str = ( "<table width=\"100%\">"
                    "<tr><th>Term</th><th>Value</th></tr>"
                    + rows + "</table>" )

        default_opts = { 'title': name }
        if opts is not None:
            default_opts.update(opts)
        if win is not None:
            self.vis.text(tbl_str, win=win, opts=default_opts)
        else:
            self.cur_win[name] = self.vis.text(tbl_str, opts=default_opts)
79 |
80 |
if __name__=='__main__':
    # Manual smoke test: requires a visdom server listening on port 13500.
    import numpy as np
    vis = Visualizer(port=13500, env='main')
    for tbl in ({"lr": 214, "momentum": 0.9},
                {"lr": 244444, "momentum": 0.9, "haha": "hoho"}):
        vis.vis_table("test_table", tbl)
--------------------------------------------------------------------------------
/SyncBN/syncbn.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.nn.modules.batchnorm import _BatchNorm
3 | import torch.nn as nn
4 | import torch
5 | import torch.nn.functional as F
6 | from queue import Queue
7 |
8 | from .functions import SyncBNFunction
9 |
10 | import collections
11 |
12 | _bn_context = collections.namedtuple("_bn_context", ['sync', 'is_master', 'cur_device', 'queue', 'devices'])
13 |
class SyncBatchNorm(_BatchNorm):
    """ Sync BN

    Batch normalization whose batch statistics are synchronized across all
    visible CUDA devices (intended for use under torch.nn.DataParallel).
    Device 0 acts as the master: it collects per-device statistics through
    queues, reduces them, and broadcasts the result back (see SyncBNFunction).
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)

        # One entry per visible GPU; synchronization only kicks in when
        # more than one device is present.
        self.devices = list(range(torch.cuda.device_count()))
        self.sync = len(self.devices)>1
        self._slaves = self.devices[1:]
        # _queues[0]: master collects one (mean, sq-mean) tuple per slave;
        # _queues[i], i >= 1: master hands the reduced stats back to slave i.
        # NOTE(review): slaves later index this list with their raw device id
        # (ctx.queue[cur_device]), which assumes device ids are the
        # contiguous range 0..N-1 -- confirm for masked CUDA_VISIBLE_DEVICES.
        self._queues = [ Queue(len(self._slaves)) ] + [ Queue(1) for _ in self._slaves ]

    def _check_input_dim(self, input):
        # Base class accepts anything with at least (N, C) leading dims;
        # the 1d/2d/3d subclasses tighten this.
        if input.dim() < 2:
            raise ValueError('expected at least 2 dims (got {}D input)'
                .format(input.dim()))

    def forward(self, input):
        self._check_input_dim(input)

        # Fast path: evaluation with tracked stats is plain batch_norm using
        # the running buffers (momentum=0.0 leaves them untouched).
        if not self.training and self.track_running_stats:
            return F.batch_norm(input, running_mean=self.running_mean, running_var=self.running_var,
                weight=self.weight, bias=self.bias, training=False, momentum=0.0, eps=self.eps)
        else:
            # Mirror nn._BatchNorm's momentum handling: momentum=None means
            # a cumulative moving average (factor 1/num_batches_tracked).
            exponential_average_factor = 0.0
            if self.num_batches_tracked is not None: # track running statistics
                self.num_batches_tracked += 1
                if self.momentum is None: # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else: # use exponential moving average
                    exponential_average_factor = self.momentum

            # Build the per-call sync context; CPU inputs never synchronize.
            # NOTE(review): eval with track_running_stats=False reaches this
            # branch and SyncBNFunction will read running_mean/var, which are
            # None in that configuration -- confirm intended usage.
            if input.is_cuda:
                cur_device = input.get_device()
                bn_ctx = _bn_context( self.sync, (cur_device==self.devices[0]), cur_device, self._queues, self.devices )
            else:
                bn_ctx = _bn_context( False, True, None, None, None )

            return SyncBNFunction.apply( input, self.weight, self.bias, self.running_mean, self.running_var, exponential_average_factor, self.eps, self.training, bn_ctx )
52 |
53 |
class SyncBatchNorm1d(SyncBatchNorm):
    """SyncBatchNorm over 2D (N, C) or 3D (N, C, L) inputs."""
    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
59 |
60 |
class SyncBatchNorm2d(SyncBatchNorm):
    """SyncBatchNorm over 4D (N, C, H, W) inputs."""
    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim != 4:
            raise ValueError('expected 4D input (got {}D input)'.format(ndim))
66 |
67 |
class SyncBatchNorm3d(SyncBatchNorm):
    """SyncBatchNorm over 5D (N, C, D, H, W) inputs."""
    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(ndim))
73 |
--------------------------------------------------------------------------------
/examples/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
class BasicBlock(nn.Module):
    """Two 3x3 conv residual block (ResNet-18/34 style).

    A 1x1 projection shortcut is inserted whenever the spatial stride or
    the channel count changes; otherwise the shortcut is the identity.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)
35 |
36 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet-50+ style).

    Output channels are `expansion * planes`; a 1x1 projection shortcut is
    used whenever stride or channel count differ from the input.
    """
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()  # identity
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return F.relu(out + residual)
63 |
64 |
class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, four stages, 4x4 avg-pool, linear head.

    `block` is the residual block class (must expose `expansion`), and
    `num_blocks` gives the block count per stage.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest keep stride 1.
        layers = []
        for blk_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, blk_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)
96 |
97 |
def resnet18():
    """18-layer ResNet built from BasicBlocks."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])


def resnet34():
    """34-layer ResNet built from BasicBlocks."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3])


def resnet50():
    """50-layer ResNet built from Bottleneck blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3])


def resnet101():
    """101-layer ResNet built from Bottleneck blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3])


def resnet152():
    """152-layer ResNet built from Bottleneck blocks."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3])
112 |
113 |
--------------------------------------------------------------------------------
/examples/utils/stream_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
class _StreamMetrics(object):
    """Interface for streaming (incrementally updated) metric trackers."""
    def __init__(self):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def update(self, gt, pred):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def get_results(self):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def to_str(self, metrics):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def reset(self):
        """ Overridden by subclasses """
        raise NotImplementedError()


class StreamClsMetrics(_StreamMetrics):
    """Single-label classification metrics accumulated over mini-batches
    via an n_classes x n_classes confusion matrix."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        """Accumulate one batch of (true, predicted) label pairs.

        NOTE(review): cells are indexed [pred][true]; the aggregate metrics
        below are unaffected by the transpose, but confirm the orientation
        before adding per-class precision/recall.
        """
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix[lp][lt] += 1

    @staticmethod
    def to_str(results):
        """Format a get_results() dict as a human-readable multi-line string."""
        string = "\n"
        for k, v in results.items():
            if k != "Class IoU":
                string += "%s: %f\n" % (k, v)

        string += 'Class IoU:\n'
        for k, v in results['Class IoU'].items():
            string += "\tclass %d: %f\n" % (k, v)
        return string

    def get_results(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IU
            - fwavacc
        """
        hist = self.confusion_matrix
        diag = np.diag(hist)
        acc = diag.sum() / hist.sum()
        # nanmean skips classes with no samples (0/0 rows).
        acc_cls = np.nanmean(diag / hist.sum(axis=1))
        iu = diag / (hist.sum(axis=1) + hist.sum(axis=0) - diag)
        mean_iu = np.nanmean(iu)
        freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {
            "Overall Acc": acc,
            "Mean Acc": acc_cls,
            "FreqW Acc": fwavacc,
            "Mean IoU": mean_iu,
            "Class IoU": cls_iu,
        }

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))


class AverageMeter(object):
    """Keeps a running (sum, count) pair per id and reports the mean."""
    def __init__(self):
        self.book = dict()  # id -> [running sum, sample count]

    def reset_all(self):
        self.book.clear()

    def reset(self, id):
        item = self.book.get(id, None)
        if item is not None:
            item[0] = 0
            item[1] = 0

    def update(self, id, val):
        record = self.book.get(id, None)
        if record is None:
            self.book[id] = [val, 1]
        else:
            record[0] += val
            record[1] += 1

    def get_results(self, id):
        """Return the mean of all values recorded under `id`."""
        record = self.book.get(id, None)
        assert record is not None
        return record[0] / record[1]


class VocClsMetrics(_StreamMetrics):
    """Multi-label classification metrics with per-class binary confusion
    counts: one (TN, FP, FN, TP) row per class."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        # Column index = true*2 + pred: 00->TN, 01->FP, 10->FN, 11->TP.
        self.confusion_matrix = np.zeros(shape=(n_classes, 4))

    def update(self, label_trues, label_preds):
        """Accumulate one batch of per-class 0/1 target / prediction vectors."""
        for lt, lp in zip(label_trues, label_preds):
            # Fancy index: for every class c, bump column lt[c]*2 + lp[c].
            idx = (tuple(range(self.n_classes)), lt * 2 + lp)
            self.confusion_matrix[idx] += 1

    def to_str(self, results):
        string = "\n"
        string += "Overall Acc: %f\n" % (results['Overall Acc'])
        string += "Overall Precision: %f\n" % (results['Overall Precision'])
        string += "Overall Recall: %f\n" % (results['Overall Recall'])
        string += "Overall F1: %f\n" % (results['Overall F1'])

        string += 'Class Metrics:\n'

        for i in range(self.n_classes):
            string += "\tclass %d: acc=%f, precision=%f, recall=%f, f1=%f\n" % (
                i, results['Class Acc'][i], results['Class Precision'][i],
                results['Class Recall'][i], results['Class F1'][i])
        return string

    def get_results(self):
        """Return per-class and macro-averaged acc / precision / recall / F1."""
        TN = self.confusion_matrix[:, 0]
        FP = self.confusion_matrix[:, 1]
        FN = self.confusion_matrix[:, 2]
        TP = self.confusion_matrix[:, 3]

        # nan_to_num maps 0/0 (class never predicted / never present) to 0.
        class_accuracy = np.nan_to_num((TN + TP) / (TN + FP + FN + TP))
        class_precision = np.nan_to_num(TP / (TP + FP))
        class_recall = np.nan_to_num(TP / (TP + FN))
        class_f1 = np.nan_to_num(2 * (class_precision * class_recall)
                                 / (class_precision + class_recall))

        return {'Overall Acc': class_accuracy.mean(),
                'Overall Precision': class_precision.mean(),
                'Overall Recall': class_recall.mean(),
                'Overall F1': class_f1.mean(),
                'Class Acc': class_accuracy,
                'Class Precision': class_precision,
                'Class Recall': class_recall,
                'Class F1': class_f1}

    def reset(self):
        # BUG FIX: previously created an unused `self.correct` array, leaving
        # confusion_matrix (the attribute every other method reads) untouched,
        # so reset() silently did nothing.
        self.confusion_matrix = np.zeros(shape=(self.n_classes, 4))
148 |
--------------------------------------------------------------------------------
/SyncBN/functions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from torch.autograd.function import once_differentiable
4 | import torch.nn.functional as F
5 |
def unsqueeze(tensor):
    """Lift a per-channel (C,) tensor to (1, C, 1) for broadcasting over
    (B, C, L)-shaped activations (generally inserts size-1 dims at the
    front and at position 2)."""
    return tensor.unsqueeze(0).unsqueeze(2)
8 |
class SyncBNFunction(Function):
    """Autograd function implementing cross-device synchronized batch norm.

    Statistics are exchanged between the DataParallel replicas through the
    queues carried in ``bn_ctx`` (see SyncBatchNorm): each slave pushes its
    per-device mean / squared mean to the master's queue, the master reduces
    them and hands the global values back, and the same handshake is repeated
    for the statistic gradients in backward.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, running_mean, running_var, momentum, eps, training, bn_ctx):
        x_shape = x.shape
        B, C = x_shape[:2]

        # Flatten all spatial dims: stats are computed over dims (0, 2).
        _x = x.view(B,C,-1).contiguous()

        ctx.eps = eps
        ctx.training = training

        # Stash the sync context for backward (same handshake is replayed).
        ctx.sync = bn_ctx.sync
        ctx.cur_device = bn_ctx.cur_device
        ctx.queue = bn_ctx.queue
        ctx.is_master = bn_ctx.is_master
        ctx.devices = bn_ctx.devices

        # 1 / (per-device sample count per channel).
        norm = 1/(_x.shape[0] * _x.shape[2])

        if ctx.training:
            # Per-device E[x] and E[x^2] per channel.
            _ex = _x.sum(2).sum(0) * norm
            _exs = _x.pow(2).sum(2).sum(0) * norm

            if ctx.sync:
                if ctx.is_master:

                    _ex, _exs = [_ex.unsqueeze(1)], [_exs.unsqueeze(1)]

                    # Collect one stats pair from every slave, then average.
                    master_queue = ctx.queue[0]
                    for j in range(master_queue.maxsize):
                        _slave_ex, _slave_exs = master_queue.get()
                        master_queue.task_done()

                        _ex.append( _slave_ex.unsqueeze(1) )
                        _exs.append( _slave_exs.unsqueeze(1) )

                    # NOTE: mean over devices assumes equal per-device batch
                    # sizes (DataParallel's even split).
                    _ex = torch.cuda.comm.gather( _ex, dim=1 ).mean(1)
                    _exs = torch.cuda.comm.gather( _exs, dim=1 ).mean(1)

                    # Ship the reduced stats back, one queue per slave.
                    distributed_tensor = torch.cuda.comm.broadcast_coalesced( (_ex, _exs), ctx.devices )

                    for dt, q in zip( distributed_tensor[1:], ctx.queue[1:] ):
                        q.put(dt)
                else:
                    # Slave: hand local stats to the master, block for the
                    # reduced values. NOTE(review): queue[cur_device] assumes
                    # device ids are the contiguous range 0..N-1 -- confirm.
                    master_queue = ctx.queue[0]
                    slave_queue = ctx.queue[ctx.cur_device]
                    master_queue.put( (_ex, _exs) )

                    _ex, _exs = slave_queue.get()
                    slave_queue.task_done()
                    _ex, _exs = _ex.squeeze(), _exs.squeeze()

            # Biased variance from the two moments; the unbiased version
            # (global sample count N) feeds the running buffer, matching
            # standard BatchNorm semantics.
            _var = _exs - _ex.pow(2)
            N = B*len(ctx.devices)
            unbiased_var = _var * N / (N - 1)

            # In-place EMA update of the running buffers.
            running_mean.mul_( (1-momentum) ).add_( momentum * _ex )
            running_var.mul_( (1-momentum) ).add_( momentum * unbiased_var )
            ctx.mark_dirty(running_mean, running_var)
        else:
            # Eval fallback: reconstruct the moments from the running buffers.
            # NOTE(review): crashes if track_running_stats=False (buffers are
            # None) -- confirm this path is only hit with tracked stats.
            _ex, _var = running_mean.contiguous(), running_var.contiguous()
            _exs = _ex.pow(2) + _var

        invstd = 1/torch.sqrt( _var + eps )

        if weight is not None: # affine
            output = (_x - unsqueeze(_ex) ) * unsqueeze(invstd) * unsqueeze(weight) + unsqueeze(bias)
        else:
            output = (_x - unsqueeze(_ex) ) * unsqueeze(invstd)

        ctx.save_for_backward(x, _ex, _exs, weight, bias)
        return output.view(*x_shape).contiguous().clone()

    @staticmethod
    def backward(ctx, grad_output):
        x, _ex, _exs, weight, bias = ctx.saved_tensors
        grad_x = grad_weight = grad_bias = None

        B,C = grad_output.shape[:2]
        grad_output_shape = grad_output.shape

        _var = _exs - _ex.pow(2)
        _std = torch.sqrt( _var + ctx.eps)
        invstd = 1.0 / _std

        # Same (B, C, L) layout as forward.
        grad_output = grad_output.view(B,C,-1)
        x = x.view(B,C,-1)

        norm = 1.0/(x.shape[0] * x.shape[2])

        # dot_p = sum over batch/spatial of grad_out * (x - mean); also the
        # unscaled weight gradient.
        dot_p = ( grad_output * ( x - unsqueeze( _ex ) ) ).sum(2).sum(0)
        grad_output_sum = grad_output.sum(2).sum(0)

        # NOTE(review): assumes affine=True -- `weight` is None otherwise and
        # this multiplication would raise. Confirm non-affine usage.
        grad_scale = weight * invstd

        # Gradients w.r.t. the two moments E[x] and E[x^2].
        grad_ex = -grad_output_sum * grad_scale + _ex * invstd * invstd * dot_p * grad_scale
        grad_exs = -0.5 * grad_scale * invstd * invstd * dot_p

        # Sync: average the moment gradients across devices with the same
        # master/slave queue handshake used in forward.
        if ctx.training:
            if ctx.sync:
                if ctx.is_master:
                    grad_ex, grad_exs = [grad_ex.unsqueeze(1)], [grad_exs.unsqueeze(1)]
                    master_queue = ctx.queue[0]
                    for j in range(master_queue.maxsize):
                        grad_slave_ex, grad_slave_exs = master_queue.get()
                        master_queue.task_done()

                        grad_ex.append( grad_slave_ex.unsqueeze(1) )
                        grad_exs.append( grad_slave_exs.unsqueeze(1) )

                    grad_ex = torch.cuda.comm.gather( grad_ex, dim=1 ).mean(1)
                    grad_exs = torch.cuda.comm.gather( grad_exs, dim=1).mean(1)

                    distributed_tensor = torch.cuda.comm.broadcast_coalesced( (grad_ex, grad_exs), ctx.devices )
                    for dt, q in zip( distributed_tensor[1:], ctx.queue[1:] ):
                        q.put(dt)
                else:
                    master_queue = ctx.queue[0]
                    slave_queue = ctx.queue[ctx.cur_device]
                    master_queue.put( (grad_ex, grad_exs) )

                    grad_ex, grad_exs = slave_queue.get()
                    slave_queue.task_done()
                    grad_ex, grad_exs = grad_ex.squeeze(), grad_exs.squeeze()

        if ctx.needs_input_grad[0]:
            # Chain rule through output = (x - ex) * invstd * w + b with
            # ex, exs themselves functions of x (hence the norm terms).
            grad_x = grad_output * unsqueeze( grad_scale ) + unsqueeze( grad_ex * norm ) + unsqueeze(grad_exs) * 2 * x * norm

        if ctx.needs_input_grad[1]:
            grad_weight = dot_p * invstd

        if ctx.needs_input_grad[2]:
            grad_bias = grad_output_sum

        # One gradient slot per forward input; the non-tensor inputs
        # (momentum, eps, training, bn_ctx) and running buffers get None.
        return grad_x.view(*grad_output_shape), grad_weight, grad_bias, None, None, None, None, None, None
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
--------------------------------------------------------------------------------
/examples/cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from resnet import resnet18, resnet34, resnet50, resnet101, resnet152
5 | from torchvision.datasets import CIFAR10
6 | from torchvision import transforms
7 | import argparse
8 | import os, sys
9 |
10 | import utils
11 | import numpy as np
12 | import random
13 | from utils import Visualizer
14 |
15 | from torch.utils import data
16 |
17 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
18 | from SyncBN import SyncBatchNorm2d
19 | from SyncBN.utils import convert_sync_batchnorm
20 | from tqdm import tqdm
21 |
def get_argparser():
    """Build the CLI argument parser for the CIFAR-10 training example."""
    parser = argparse.ArgumentParser()

    # Dataset Options
    parser.add_argument("--data_root", type=str, default='./data',
                        help="path to Dataset")

    # Train Options
    parser.add_argument("--model", type=str, default='resnet18',
                        choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'])
    parser.add_argument("--epochs", type=int, default=40,
                        help="epoch number (default: 40)")
    parser.add_argument("--lr", type=float, default=0.1,
                        help="learning rate (default: 0.1)")

    parser.add_argument("--batch_size", type=int, default=128,
                        help='batch size (default: 128)')
    # BUG FIX: help text was a copy-paste of --batch_size's
    # ("batch size (default: 150)").
    parser.add_argument("--lr_decay_step", type=int, default=150,
                        help='learning rate decay step (default: 150)')
    parser.add_argument("--ckpt", default=None, type=str,
                        help="path to trained model. Leave it None if you want to retrain your model")
    parser.add_argument("--gpu_id", type=str, default='0',
                        help="GPU ID")

    parser.add_argument("--momentum", type=float, default=0.9,
                        help='momentum for SGD (default: 0.9)')
    parser.add_argument("--weight_decay", type=float, default=5e-4,
                        help='weight decay (default: 5e-4)')

    parser.add_argument("--num_workers", type=int, default=4,
                        help='number of workers (default: 4)')
    parser.add_argument("--val_on_trainset", action='store_true', default=False,
                        help="enable validation on train set (default: False)")
    parser.add_argument("--random_seed", type=int, default=23333,
                        help="random seed (default: 23333)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=1,
                        help="epoch interval for eval (default: 1)")
    parser.add_argument("--ckpt_interval", type=int, default=1,
                        help="saving interval (default: 1)")
    parser.add_argument("--download", action='store_true', default=False,
                        help="download datasets")
    parser.add_argument("--sync_bn", action='store_true', default=False,
                        help="sync batchnorm")

    # Visdom options
    parser.add_argument("--enable_vis", action='store_true', default=False,
                        help="use visdom for visualization")
    parser.add_argument("--vis_port", type=str, default='15555',
                        help='port for visdom')
    parser.add_argument("--vis_env", type=str, default='main',
                        help='env for visdom')
    parser.add_argument("--trace_name", type=str, default=None)
    return parser
77 |
78 |
79 |
def train(cur_epoch, criterion, model, optim, train_loader, device, scheduler=None, print_interval=10, vis=None, trace_name=None):
    """Train and return epoch loss.

    Runs one pass over `train_loader`, stepping `scheduler` first (if given),
    logging the windowed loss every `print_interval` batches (optionally to
    visdom via `vis`), and returning the mean batch loss of the epoch.
    """
    if scheduler is not None:
        scheduler.step()
    print("Epoch %d, lr = %f"%(cur_epoch, optim.param_groups[0]['lr']))
    running_loss = 0.0
    window_loss = 0.0

    for step, (images, labels) in enumerate(train_loader):
        # N, C, H, W
        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.long)

        optim.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optim.step()

        batch_loss = loss.detach().cpu().numpy()
        running_loss += batch_loss
        window_loss += batch_loss

        if (step + 1) % print_interval == 0:
            window_loss = window_loss / print_interval
            print("Epoch %d, Batch %d/%d, Loss=%f"%(cur_epoch, step+1, len(train_loader), window_loss))
            if vis is not None:
                x = cur_epoch * len(train_loader) + step + 1
                vis.vis_scalar('Loss', trace_name, x, window_loss)
            window_loss = 0.0
    return running_loss / len(train_loader)
114 |
115 |
def validate(model, loader, device, metrics):
    """Run `model` over `loader` without gradients and return the
    accumulated results from `metrics` (reset beforehand)."""
    metrics.reset()
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            # Argmax over the class dimension gives the hard prediction.
            preds = model(images).detach().max(dim=1)[1].cpu().numpy()
            metrics.update(labels.cpu().numpy(), preds)
        score = metrics.get_results()
    return score
132 |
133 |
def main():
    """Train a ResNet classifier on CIFAR10.

    Parses command-line options, builds the dataset/model/optimizer,
    optionally converts BatchNorm layers to SyncBN, and runs the
    train/validate/checkpoint loop. Progress is logged to a visdom
    Visualizer when available.
    """
    opts = get_argparser().parse_args()
    # Set up visualization
    vis = Visualizer(port=opts.vis_port, env=opts.vis_env)
    if vis is not None:  # display options
        vis.vis_table( "%s opts"%opts.trace_name, vars(opts) )
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
    print("Device: %s"%device)

    # Set up random seed for reproducibility
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up dataloader (ImageNet normalization stats are reused here)
    train_dst = CIFAR10(root='./data', train=True,
                        transform=transforms.Compose([
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize( mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225] )]),
                        download=opts.download )

    val_dst = CIFAR10(root='./data', train=False,
                      transform=transforms.Compose([
                          transforms.ToTensor(),
                          transforms.Normalize( mean=[0.485, 0.456, 0.406],
                                                std=[0.229, 0.224, 0.225] )]),
                      download=False )

    train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=opts.num_workers)
    val_loader = data.DataLoader(val_dst, batch_size=opts.batch_size, shuffle=False, num_workers=opts.num_workers)
    print("Dataset: CIFAR10, Train set: %d, Val set: %d"%(len(train_dst), len(val_dst)))

    model = {"resnet18": resnet18,
             "resnet34": resnet34,
             "resnet50": resnet50,
             "resnet101": resnet101,
             "resnet152": resnet152 }[opts.model]()

    trace_name = opts.trace_name
    if opts.sync_bn==True:
        print("Use sync batchnorm")
        model = convert_sync_batchnorm(model)
    print(model)

    if torch.cuda.device_count()>1: # Parallel
        print("%d GPU parallel"%(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
        model_ref = model.module # model_ref holds the unwrapped module for ckpt save/load
    else:
        model_ref = model
    model = model.to(device)

    # Set up metrics (10 CIFAR10 classes)
    metrics = utils.StreamClsMetrics(10)

    # Set up optimizer: weight decay only on the "decay" parameter group
    group_decay, group_no_decay = utils.group_params(model_ref)
    assert(len(group_decay)+len(group_no_decay) == len(list(model_ref.parameters())))
    optimizer = torch.optim.SGD([ {'params': group_decay, 'weight_decay': opts.weight_decay},
                                  {'params': group_no_decay}],
                                lr=opts.lr, momentum=opts.momentum )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=0.1)
    print("optimizer:\n%s"%(optimizer))

    utils.mkdir('checkpoints')

    # Checkpoint paths follow the selected model (was hard-coded to resnet34)
    latest_ckpt_path = 'checkpoints/latest_%s_cifar10.pkl'%opts.model
    best_ckpt_path = 'checkpoints/best_%s_cifar10.pkl'%opts.model

    # Restore from checkpoint if one was given and exists
    best_score = 0.0
    cur_epoch = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt)
        model_ref.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_epoch = checkpoint["epoch"]+1  # resume from the epoch after the saved one
        best_score = checkpoint['best_score']
        print("Model restored from %s"%opts.ckpt)
        del checkpoint # free memory
    else:
        print("[!] Retrain")

    def save_ckpt(path):
        """Save model/optimizer/scheduler state plus bookkeeping to *path*."""
        state = {
            "epoch": cur_epoch,
            "model_state": model_ref.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }
        torch.save(state, path)
        print( "Model saved as %s"%path )

    # Set up criterion
    criterion = nn.CrossEntropyLoss(reduction='mean')
    #========== Train Loop ==========#
    while cur_epoch < opts.epochs:
        # ===== Train =====
        model.train()
        epoch_loss = train(cur_epoch=cur_epoch, criterion=criterion, model=model, optim=optimizer, train_loader=train_loader, device=device, scheduler=scheduler, vis=vis, trace_name=trace_name)
        print("End of Epoch %d/%d, Average Loss=%f"%(cur_epoch, opts.epochs, epoch_loss))

        if opts.enable_vis:
            vis.vis_scalar("Epoch Loss", trace_name, cur_epoch, epoch_loss )

        # ===== Save Latest Model =====
        if (cur_epoch+1)%opts.ckpt_interval==0:
            save_ckpt( latest_ckpt_path )

        # ===== Validation =====
        if (cur_epoch+1)%opts.val_interval==0:
            print("validate on val set...")
            model.eval()
            val_score = validate(model=model, loader=val_loader, device=device, metrics=metrics)
            print(metrics.to_str(val_score))

            # ===== Save Best Model =====
            # BUGFIX: compare and record the same classification metric
            # ('Mean IoU' is a segmentation key StreamClsMetrics never produces),
            # and save to a dedicated best-model path instead of overwriting
            # the latest checkpoint.
            if val_score['Overall Acc']>best_score: # save best model
                best_score = val_score['Overall Acc']
                save_ckpt( best_ckpt_path )

            if vis is not None: # visualize validation score and samples
                vis.vis_scalar("[Val] Overall Acc",trace_name, cur_epoch, val_score['Overall Acc'] )

            if opts.val_on_trainset==True: # validate on train set
                print("validate on train set...")
                model.eval()
                train_score = validate(model=model, loader=train_loader, device=device, metrics=metrics)
                print(metrics.to_str(train_score))
                if vis is not None:
                    vis.vis_scalar("[Train] Overall Acc", trace_name, cur_epoch, train_score['Overall Acc'] )

        cur_epoch+=1
272 |
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()
275 |
276 |
277 |
278 |
279 |
280 |
281 |
--------------------------------------------------------------------------------