├── .gitignore
├── .idea
│   └── vcs.xml
├── LICENSE
├── README.md
├── bn_fusion.py
├── dog.jpg
├── test_convert_inference.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Aleksei Tiulpin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Batch Norm Fusion for PyTorch

## About

In this repository, we present a simple implementation of batch norm fusion
for the most popular CNN architectures in PyTorch. This package aims to speed
up inference at test time: **the expected boost is around 30%**.

## How it works

Both convolution and batch normalization are linear operations on a data
point x, so their composition can be written as a chain of matrix
multiplications:

![T_{bn}*S_{bn}*W_{conv}*x](https://latex.codecogs.com/gif.latex?T_{bn}*S_{bn}*W_{conv}*x),

where we first apply the convolution to the data, then scale the result and
finally shift it using the batchnorm-trained parameters. Since a composition
of linear maps is itself linear, the scaling and shifting can be folded into
the convolution once training is finished: the fused layer uses the weight
`W' = γ·W / √(σ² + ε)` and the bias `b' = β + γ·(b − μ) / √(σ² + ε)`, where
`W`, `b` are the convolution's weight and bias, `γ`, `β` are the batchnorm's
affine parameters, and `μ`, `σ²` are its running statistics.

## Supported architectures

We support any architecture where Conv2d and BatchNorm2d are combined in an
`nn.Sequential` module. If you want to optimize your own networks with this
tool, just follow this design. For convenience, we wrapped the VGG, ResNet
and SENet families to demonstrate how your models can be converted into such
a format.

- [x] VGG from `torchvision`.
- [x] ResNet family from `torchvision`.
- [x] SENet family from `pretrainedmodels`.

## How to use

```python
import torchvision.models as models
from bn_fusion import fuse_bn_recursively

net = getattr(models, 'vgg16_bn')(pretrained=True)
net = fuse_bn_recursively(net)
net.eval()
# Make inference with the converted model
```
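
To sanity-check a conversion, you can compare the outputs of the original and
the fused network on the same input. Below is a minimal sketch (the input size
and the tolerance are arbitrary choices):

```python
import torch
import torchvision.models as models
from bn_fusion import fuse_bn_recursively

original = models.vgg16_bn(pretrained=True).eval()
# Fuse a separate instance: fusion modifies the model's weights in place.
fused = fuse_bn_recursively(models.vgg16_bn(pretrained=True)).eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(original(x), fused(x), atol=1e-4)
```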
## TODO

- [ ] Tests.
- [ ] Performance benchmarks.

## Acknowledgements

Thanks to [@ZFTurbo](https://github.com/ZFTurbo) for the idea, discussions and
his [implementation for Keras](https://github.com/ZFTurbo/Keras-inference-time-optimizer).
--------------------------------------------------------------------------------
/bn_fusion.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def fuse_bn_sequential(block):
    """
    Takes a sequential block and fuses each BatchNorm2d layer into the
    Conv2d layer that immediately precedes it.

    :param block: nn.Module. Source block
    :return: nn.Module. Converted block
    """
    if not isinstance(block, nn.Sequential):
        return block
    stack = []
    for m in block.children():
        # Fuse only when a BatchNorm2d directly follows a Conv2d;
        # any other module is kept as-is.
        if isinstance(m, nn.BatchNorm2d) and stack and isinstance(stack[-1], nn.Conv2d):
            bn_st_dict = m.state_dict()
            conv_st_dict = stack[-1].state_dict()

            # BatchNorm params
            eps = m.eps
            mu = bn_st_dict['running_mean']
            var = bn_st_dict['running_var']
            gamma = bn_st_dict['weight']

            if 'bias' in bn_st_dict:
                beta = bn_st_dict['bias']
            else:
                beta = torch.zeros(gamma.size(0)).float().to(gamma.device)

            # Conv params
            W = conv_st_dict['weight']
            if 'bias' in conv_st_dict:
                bias = conv_st_dict['bias']
            else:
                bias = torch.zeros(W.size(0)).float().to(gamma.device)

            # Fused parameters: W' = gamma * W / sqrt(var + eps) and
            # b' = beta + gamma * (bias - mu) / sqrt(var + eps)
            denom = torch.sqrt(var + eps)
            b = beta - gamma.mul(mu).div(denom)
            A = gamma.div(denom)
            bias *= A
            A = A.expand_as(W.transpose(0, -1)).transpose(0, -1)

            W.mul_(A)
            bias.add_(b)

            stack[-1].weight.data.copy_(W)
            if stack[-1].bias is None:
                stack[-1].bias = torch.nn.Parameter(bias)
            else:
                stack[-1].bias.data.copy_(bias)
        else:
            stack.append(m)

    if len(stack) > 1:
        return nn.Sequential(*stack)
    else:
        return stack[0]


def fuse_bn_recursively(model):
    for module_name in model._modules:
        model._modules[module_name] = fuse_bn_sequential(model._modules[module_name])
        if len(model._modules[module_name]._modules) > 0:
            fuse_bn_recursively(model._modules[module_name])

    return model
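

# Illustrative sketch (not in the original module): fuse a toy Conv-BN-ReLU
# stack and check that the outputs stay the same.
if __name__ == '__main__':
    demo = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    demo.eval()  # use running statistics, as at inference time
    x = torch.randn(1, 3, 16, 16)
    with torch.no_grad():
        before = demo(x)
        fused = fuse_bn_sequential(demo)  # -> Sequential(Conv2d, ReLU)
        after = fused(x)
    print('Max abs difference after fusion:', (before - after).abs().max().item())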
--------------------------------------------------------------------------------
/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imedslab/pytorch_bn_fusion/b235f24abdd8b972d844d05964dfa4c03feb4ce1/dog.jpg
--------------------------------------------------------------------------------
/test_convert_inference.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import functional as F
import torchvision.models as models
from torchvision import transforms

import argparse
from PIL import Image
import numpy as np
from bn_fusion import fuse_bn_recursively
from utils import convert_resnet_family
import pretrainedmodels
import time

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='vgg16_bn')
    args = parser.parse_args()

    trf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    img = trf(Image.open('dog.jpg')).unsqueeze(0)

    # Look the model up in torchvision first, then fall back to pretrainedmodels
    try:
        net = getattr(models, args.model)(pretrained=True)
    except AttributeError:
        net = pretrainedmodels.__dict__[args.model](num_classes=1000, pretrained='imagenet')

    if 'resnet' in args.model:
        se = 'se' in args.model
        net = convert_resnet_family(net, se)

    # Benchmarking
    # First, we run the network the way it is (warm-up pass)
    net.eval()
    with torch.no_grad():
        F.softmax(net(img), 1)
    # Measuring non-optimized model performance
    times = []
    for i in range(50):
        start = time.time()
        with torch.no_grad():
            res_0 = F.softmax(net(img), 1)
        times.append(time.time() - start)

    print('Non-fused takes', np.mean(times), 'seconds')

    net = fuse_bn_recursively(net)
    net.eval()
    times = []
    for i in range(50):
        start = time.time()
        with torch.no_grad():
            res_1 = F.softmax(net(img), 1)
        times.append(time.time() - start)

    print('Fused takes', np.mean(times), 'seconds')

    diff = res_0 - res_1
    print('L2 norm of the element-wise difference:', diff.norm().item())
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch.nn as nn
from torchvision.models import resnet as resnet_modules
from pretrainedmodels.models import senet as senet_modules


class Net(nn.Module):
    def __init__(self, features, classifier):
        super(Net, self).__init__()
        self.features = features
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = classifier

    def forward(self, x):
        out = self.features(x)
        out = self.pool(out).view(x.size(0), -1)
        return self.classifier(out)


def convert_resnet_family(model, se=False):
    """
    Wraps any (SE-)ResNet model from torchvision or pretrainedmodels so that
    each conv+bn pair lives in an nn.Sequential and can be fused.

    :param model: nn.Module. Source model
    :param se: bool. Whether the model is an SE variant from pretrainedmodels
    :return: nn.Module. Converted model
    """

    features = list()
    if not se:
        layer0 = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool
        )
        features.append(layer0)
    else:
        features.append(model.layer0)

    for ind in range(1, 5):
        modules_layer = model._modules[f'layer{ind}']._modules
        new_modules = []
        for block_name in modules_layer:
            b = modules_layer[block_name]
            if isinstance(b, resnet_modules.BasicBlock):
                b = BasicResnetBlock(b)

            if isinstance(b, resnet_modules.Bottleneck) or \
                    isinstance(b, senet_modules.SEBottleneck) or \
                    isinstance(b, senet_modules.SEResNetBottleneck) or \
                    isinstance(b, senet_modules.SEResNeXtBottleneck):
                b = BottleneckResnetBlock(b, se)
            new_modules.append(b)
        features.append(nn.Sequential(*new_modules))

    features = nn.Sequential(*features)
    if not se:
        classifier = model.fc
    else:
        classifier = model.last_linear

    return Net(features, classifier)


class BasicResnetBlock(nn.Module):
    expansion = 1

    def __init__(self, source_block):
        super(BasicResnetBlock, self).__init__()
        self.block1 = nn.Sequential(
            source_block.conv1,
            source_block.bn1
        )

        self.block2 = nn.Sequential(
            source_block.conv2,
            source_block.bn2
        )

        self.downsample = source_block.downsample
        self.stride = source_block.stride
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        out = self.relu(self.block1(x))
        out = self.block2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BottleneckResnetBlock(nn.Module):
    expansion = 4

    def __init__(self, source_block, se=False):
        super(BottleneckResnetBlock, self).__init__()
        self.block1 = nn.Sequential(
            source_block.conv1,
            source_block.bn1,
        )

        self.block2 = nn.Sequential(
            source_block.conv2,
            source_block.bn2
        )

        self.block3 = nn.Sequential(
            source_block.conv3,
            source_block.bn3
        )

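        # Each conv+bn pair above is grouped into an nn.Sequential so that
        # fuse_bn_sequential can detect and fuse it; activations and the
        # SE module are kept outside of these fuseable blocks.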
        self.relu = nn.ReLU(inplace=True)

        self.downsample = source_block.downsample
        self.stride = source_block.stride
        if se:
            self.se_module = source_block.se_module
        else:
            self.se_module = None

    def forward(self, x):
        residual = x

        out = self.relu(self.block1(x))
        out = self.relu(self.block2(out))
        out = self.block3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # As in the source SE bottleneck: the SE module rescales the block
        # output before the residual is added
        if self.se_module is not None:
            out = self.se_module(out)
        out += residual
        out = self.relu(out)

        return out
--------------------------------------------------------------------------------