├── .gitignore ├── LICENSE ├── README.md ├── ptflops ├── __init__.py └── flops_counter.py ├── sample.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea 106 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Vladislav Sovrasov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Flops counter for convolutional networks in pytorch framework 2 | [![Pypi version](https://img.shields.io/pypi/v/ptflops.svg)](https://pypi.org/project/ptflops/) 3 | 4 | This script is designed to compute the theoretical amount of multiply-add operations 5 | in convolutional neural networks. It also can compute the number of parameters and 6 | print per-layer computational cost of a given network. 7 | 8 | Supported layers: 9 | - Convolution2d (including grouping) 10 | - BatchNorm2d 11 | - Activations (ReLU, PReLU, ELU, ReLU6, LeakyReLU) 12 | - Linear 13 | - Upsample 14 | - Poolings (AvgPool2d, MaxPool2d and adaptive ones) 15 | 16 | Requirements: Pytorch 0.4.1 or 1.0, torchvision 0.2.1 17 | 18 | Thanks to @warmspringwinds for the initial version of script. 19 | 20 | ## Install the latest version 21 | ```bash 22 | pip install --upgrade git+https://github.com/zhouyuangan/flops-counter.pytorch.git 23 | ``` 24 | 25 | ## Example 26 | ```python 27 | import torchvision.models as models 28 | import torch 29 | from ptflops import get_model_complexity_info 30 | 31 | with torch.cuda.device(0): 32 | net = models.densenet161() 33 | flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True) 34 | print('Flops: ' + flops) 35 | print('Params: ' + params) 36 | ``` 37 | 38 | ## Benchmark 39 | 40 | ### [torchvision](https://pytorch.org/docs/1.0.0/torchvision/models.html) 41 | 42 | Model | Input Resolution | Params(M) | MACs(G) | Top-1 error | Top-5 error 43 | --- |--- |--- |--- |--- |--- 44 | alexnet |224x224 | 61.1 | 0.72 | 43.45 | 20.91 45 | vgg11 |224x224 | 132.86 | 7.63 | 30.98 | 11.37 46 | vgg13 |224x224 | 133.05 | 11.34 | 30.07 | 10.75 47 | vgg16 |224x224 | 138.36 | 15.5 | 28.41 | 9.62 48 | vgg19 |224x224 | 143.67 | 19.67 | 27.62 | 9.12 49 | vgg11_bn |224x224 | 132.87 | 7.64 | 29.62 | 10.19 50 | vgg13_bn |224x224 | 133.05 | 11.36 | 28.45 | 9.63 51 | vgg16_bn |224x224 | 138.37 | 15.53 | 26.63 | 8.50 52 | vgg19_bn |224x224 | 143.68 | 19.7 | 25.76 | 8.15 53 | resnet18 |224x224 | 11.69 | 1.82 | 30.24 | 10.92 54 | resnet34 |224x224 | 21.8 | 3.68 | 26.70 | 8.58 55 | resnet50 |224x224 | 25.56 | 4.12 | 23.85 | 7.13 56 | resnet101 |224x224 | 44.55 | 7.85 | 22.63 | 6.44 57 | resnet152 |224x224 | 60.19 | 11.58 | 21.69 | 5.94 58 | squeezenet1_0 |224x224 | 1.25 | 0.83 | 41.90 | 19.58 59 | squeezenet1_1 |224x224 | 1.24 | 0.36 | 41.81 | 19.38 60 | densenet121 |224x224 | 7.98 | 2.88 | 25.35 | 7.83 61 | densenet169 |224x224 | 14.15 | 3.42 | 24.00 | 7.00 62 | densenet201 |224x224 | 20.01 | 4.37 | 22.80 | 6.43 63 | densenet161 |224x224 | 28.68 | 7.82 | 22.35 | 6.20 64 | inception_v3 |224x224 | 27.16 | 2.85 | 22.55 | 6.44 65 | 66 | * Top-1 error - ImageNet single-crop top-1 error (224x224) 67 | * Top-5 error - ImageNet single-crop top-5 error (224x224) 68 | -------------------------------------------------------------------------------- /ptflops/__init__.py: -------------------------------------------------------------------------------- 1 | from .flops_counter import get_model_complexity_info 2 | -------------------------------------------------------------------------------- /ptflops/flops_counter.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True, 7 | input_constructor=None, is_cuda=False): 8 | """ 9 | 10 | :param model: 11 | :param input_res: 12 | :param print_per_layer_stat: 13 | :param as_strings: 14 | :param input_constructor: 15 | :param is_cuda: model and input both put on GPUs 16 | :return: 17 | """ 18 | assert type(input_res) is tuple 19 | assert len(input_res) == 3 20 | if is_cuda: 21 | model = model.cuda() 22 | flops_model = add_flops_counting_methods(model) 23 | flops_model.eval().start_flops_count() 24 | if input_constructor: 25 | input = input_constructor(input_res) 26 | _ = flops_model(**input) 27 | else: 28 | if is_cuda: 29 | batch = torch.FloatTensor(1, *input_res).cuda() 30 | else: 31 | batch = torch.FloatTensor(1, *input_res) 32 | _ = flops_model(batch) 33 | 34 | if print_per_layer_stat: 35 | print_model_with_flops(flops_model) 36 | flops_count = flops_model.compute_average_flops_cost() 37 | params_count = get_model_parameters_number(flops_model) 38 | flops_model.stop_flops_count() 39 | 40 | if as_strings: 41 | return flops_to_string(flops_count), params_to_string(params_count) 42 | 43 | return flops_count, params_count 44 | 45 | def flops_to_string(flops, units='GMac', precision=2): 46 | if units is None: 47 | if flops // 10**9 > 0: 48 | return str(round(flops / 10.**9, precision)) + ' GMac' 49 | elif flops // 10**6 > 0: 50 | return str(round(flops / 10.**6, precision)) + ' MMac' 51 | elif flops // 10**3 > 0: 52 | return str(round(flops / 10.**3, precision)) + ' KMac' 53 | else: 54 | return str(flops) + ' Mac' 55 | else: 56 | if units == 'GMac': 57 | return str(round(flops / 10.**9, precision)) + ' ' + units 58 | elif units == 'MMac': 59 | return str(round(flops / 10.**6, precision)) + ' ' + units 60 | elif units == 'KMac': 61 | return str(round(flops / 10.**3, precision)) + ' ' + units 62 | else: 63 | return str(flops) + ' Mac' 64 | 65 | def params_to_string(params_num): 66 | if params_num // 10 ** 6 > 0: 67 | return str(round(params_num / 10 ** 6, 2)) + ' M' 68 | elif params_num // 10 ** 3: 69 | return str(round(params_num / 10 ** 3, 2)) + ' k' 70 | 71 | def print_model_with_flops(model, units='GMac', precision=3): 72 | total_flops = model.compute_average_flops_cost() 73 | 74 | def accumulate_flops(self): 75 | if is_supported_instance(self): 76 | return self.__flops__ / model.__batch_counter__ 77 | else: 78 | sum = 0 79 | for m in self.children(): 80 | sum += m.accumulate_flops() 81 | return sum 82 | 83 | def flops_repr(self): 84 | accumulated_flops_cost = self.accumulate_flops() 85 | return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision), 86 | '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), 87 | self.original_extra_repr()]) 88 | 89 | def add_extra_repr(m): 90 | m.accumulate_flops = accumulate_flops.__get__(m) 91 | flops_extra_repr = flops_repr.__get__(m) 92 | if m.extra_repr != flops_extra_repr: 93 | m.original_extra_repr = m.extra_repr 94 | m.extra_repr = flops_extra_repr 95 | assert m.extra_repr != m.original_extra_repr 96 | 97 | def del_extra_repr(m): 98 | if hasattr(m, 'original_extra_repr'): 99 | m.extra_repr = m.original_extra_repr 100 | del m.original_extra_repr 101 | if hasattr(m, 'accumulate_flops'): 102 | del m.accumulate_flops 103 | 104 | model.apply(add_extra_repr) 105 | print(model) 106 | model.apply(del_extra_repr) 107 | 108 | def get_model_parameters_number(model): 109 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 110 | return params_num 111 | 112 | def add_flops_counting_methods(net_main_module): 113 | # adding additional methods to the existing module object, 114 | # this is done this way so that each function has access to self object 115 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 116 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 117 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 118 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 119 | 120 | net_main_module.reset_flops_count() 121 | 122 | # Adding variables necessary for masked flops computation 123 | net_main_module.apply(add_flops_mask_variable_or_reset) 124 | 125 | return net_main_module 126 | 127 | 128 | def compute_average_flops_cost(self): 129 | """ 130 | A method that will be available after add_flops_counting_methods() is called 131 | on a desired net object. 132 | 133 | Returns current mean flops consumption per image. 134 | 135 | """ 136 | 137 | batches_count = self.__batch_counter__ 138 | flops_sum = 0 139 | for module in self.modules(): 140 | if is_supported_instance(module): 141 | flops_sum += module.__flops__ 142 | 143 | return flops_sum / batches_count 144 | 145 | 146 | def start_flops_count(self): 147 | """ 148 | A method that will be available after add_flops_counting_methods() is called 149 | on a desired net object. 150 | 151 | Activates the computation of mean flops consumption per image. 152 | Call it before you run the network. 153 | 154 | """ 155 | add_batch_counter_hook_function(self) 156 | self.apply(add_flops_counter_hook_function) 157 | 158 | 159 | def stop_flops_count(self): 160 | """ 161 | A method that will be available after add_flops_counting_methods() is called 162 | on a desired net object. 163 | 164 | Stops computing the mean flops consumption per image. 165 | Call whenever you want to pause the computation. 166 | 167 | """ 168 | remove_batch_counter_hook_function(self) 169 | self.apply(remove_flops_counter_hook_function) 170 | 171 | 172 | def reset_flops_count(self): 173 | """ 174 | A method that will be available after add_flops_counting_methods() is called 175 | on a desired net object. 176 | 177 | Resets statistics computed so far. 178 | 179 | """ 180 | add_batch_counter_variables_or_reset(self) 181 | self.apply(add_flops_counter_variable_or_reset) 182 | 183 | 184 | def add_flops_mask(module, mask): 185 | def add_flops_mask_func(module): 186 | if isinstance(module, torch.nn.Conv2d): 187 | module.__mask__ = mask 188 | module.apply(add_flops_mask_func) 189 | 190 | 191 | def remove_flops_mask(module): 192 | module.apply(add_flops_mask_variable_or_reset) 193 | 194 | 195 | # ---- Internal functions 196 | def is_supported_instance(module): 197 | if isinstance(module, (torch.nn.Conv2d, torch.nn.ReLU, torch.nn.PReLU, torch.nn.ELU, \ 198 | torch.nn.LeakyReLU, torch.nn.ReLU6, torch.nn.Linear, \ 199 | torch.nn.MaxPool2d, torch.nn.AvgPool2d, torch.nn.BatchNorm2d, \ 200 | torch.nn.Upsample, nn.AdaptiveMaxPool2d, nn.AdaptiveAvgPool2d)): 201 | return True 202 | 203 | return False 204 | 205 | 206 | def empty_flops_counter_hook(module, input, output): 207 | module.__flops__ += 0 208 | 209 | 210 | def upsample_flops_counter_hook(module, input, output): 211 | output_size = output[0] 212 | batch_size = output_size.shape[0] 213 | output_elements_count = batch_size 214 | for val in output_size.shape[1:]: 215 | output_elements_count *= val 216 | module.__flops__ += int(output_elements_count) 217 | 218 | 219 | def relu_flops_counter_hook(module, input, output): 220 | active_elements_count = output.numel() 221 | module.__flops__ += int(active_elements_count) 222 | 223 | 224 | def linear_flops_counter_hook(module, input, output): 225 | input = input[0] 226 | batch_size = input.shape[0] 227 | module.__flops__ += int(batch_size * input.shape[1] * output.shape[1]) 228 | 229 | 230 | def pool_flops_counter_hook(module, input, output): 231 | input = input[0] 232 | module.__flops__ += int(np.prod(input.shape)) 233 | 234 | def bn_flops_counter_hook(module, input, output): 235 | module.affine 236 | input = input[0] 237 | 238 | batch_flops = np.prod(input.shape) 239 | if module.affine: 240 | batch_flops *= 2 241 | module.__flops__ += int(batch_flops) 242 | 243 | def conv_flops_counter_hook(conv_module, input, output): 244 | # Can have multiple inputs, getting the first one 245 | input = input[0] 246 | 247 | batch_size = input.shape[0] 248 | output_height, output_width = output.shape[2:] 249 | 250 | kernel_height, kernel_width = conv_module.kernel_size 251 | in_channels = conv_module.in_channels 252 | out_channels = conv_module.out_channels 253 | groups = conv_module.groups 254 | 255 | filters_per_channel = out_channels // groups 256 | conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel 257 | 258 | active_elements_count = batch_size * output_height * output_width 259 | 260 | if conv_module.__mask__ is not None: 261 | # (b, 1, h, w) 262 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 263 | active_elements_count = flops_mask.sum() 264 | 265 | overall_conv_flops = conv_per_position_flops * active_elements_count 266 | 267 | bias_flops = 0 268 | 269 | if conv_module.bias is not None: 270 | 271 | bias_flops = out_channels * active_elements_count 272 | 273 | overall_flops = overall_conv_flops + bias_flops 274 | 275 | conv_module.__flops__ += int(overall_flops) 276 | 277 | 278 | def batch_counter_hook(module, input, output): 279 | batch_size = 1 280 | if len(input) > 0: 281 | # Can have multiple inputs, getting the first one 282 | input = input[0] 283 | batch_size = len(input) 284 | else: 285 | pass 286 | print('Warning! No positional inputs found for a module, assuming batch size is 1.') 287 | module.__batch_counter__ += batch_size 288 | 289 | 290 | def add_batch_counter_variables_or_reset(module): 291 | 292 | module.__batch_counter__ = 0 293 | 294 | 295 | def add_batch_counter_hook_function(module): 296 | if hasattr(module, '__batch_counter_handle__'): 297 | return 298 | 299 | handle = module.register_forward_hook(batch_counter_hook) 300 | module.__batch_counter_handle__ = handle 301 | 302 | 303 | def remove_batch_counter_hook_function(module): 304 | if hasattr(module, '__batch_counter_handle__'): 305 | module.__batch_counter_handle__.remove() 306 | del module.__batch_counter_handle__ 307 | 308 | 309 | def add_flops_counter_variable_or_reset(module): 310 | if is_supported_instance(module): 311 | module.__flops__ = 0 312 | 313 | 314 | def add_flops_counter_hook_function(module): 315 | if is_supported_instance(module): 316 | if hasattr(module, '__flops_handle__'): 317 | return 318 | 319 | if isinstance(module, torch.nn.Conv2d): 320 | handle = module.register_forward_hook(conv_flops_counter_hook) 321 | elif isinstance(module, (torch.nn.ReLU, torch.nn.PReLU, torch.nn.ELU, \ 322 | torch.nn.LeakyReLU, torch.nn.ReLU6)): 323 | handle = module.register_forward_hook(relu_flops_counter_hook) 324 | elif isinstance(module, torch.nn.Linear): 325 | handle = module.register_forward_hook(linear_flops_counter_hook) 326 | elif isinstance(module, (torch.nn.AvgPool2d, torch.nn.MaxPool2d, nn.AdaptiveMaxPool2d, \ 327 | nn.AdaptiveAvgPool2d)): 328 | handle = module.register_forward_hook(pool_flops_counter_hook) 329 | elif isinstance(module, torch.nn.BatchNorm2d): 330 | handle = module.register_forward_hook(bn_flops_counter_hook) 331 | elif isinstance(module, torch.nn.Upsample): 332 | handle = module.register_forward_hook(upsample_flops_counter_hook) 333 | else: 334 | handle = module.register_forward_hook(empty_flops_counter_hook) 335 | module.__flops_handle__ = handle 336 | 337 | 338 | def remove_flops_counter_hook_function(module): 339 | if is_supported_instance(module): 340 | if hasattr(module, '__flops_handle__'): 341 | module.__flops_handle__.remove() 342 | del module.__flops_handle__ 343 | # --- Masked flops counting 344 | 345 | 346 | # Also being run in the initialization 347 | def add_flops_mask_variable_or_reset(module): 348 | if is_supported_instance(module): 349 | module.__mask__ = None 350 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchvision.models as models 3 | import torch 4 | from ptflops import get_model_complexity_info 5 | 6 | pt_models = {'resnet18': models.resnet18, 'resnet50': models.resnet50, 7 | 'alexnet': models.alexnet, 8 | 'vgg16': models.vgg16, 9 | 'squeezenet': models.squeezenet1_0, 10 | 'densenet': models.densenet161, 11 | 'inception': models.inception_v3} 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser(description='Flops counter sample script.') 15 | parser.add_argument('--device', type=int, default=-1, help='Device to store the model.') 16 | parser.add_argument('--model', choices=list(pt_models.keys()), type=str, default='resnet18') 17 | args = parser.parse_args() 18 | 19 | with torch.cuda.device(args.device): 20 | net = pt_models[args.model]() 21 | flops, params = get_model_complexity_info(net, (224, 224), as_strings=True, print_per_layer_stat=True) 22 | print('Flops: ' + flops) 23 | print('Params: ' + params) 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | from setuptools import setup, find_packages 5 | 6 | readme = open('README.md').read() 7 | 8 | VERSION = '0.1' 9 | 10 | requirements = [ 11 | 'torch', 12 | ] 13 | 14 | setup( 15 | # Metadata 16 | name='ptflops', 17 | version=VERSION, 18 | author='Vladislav Sovrasov', 19 | author_email='sovrasov.vlad@gmail.com', 20 | url='https://github.com/sovrasov/flops-counter.pytorch', 21 | description='Flops counter for convolutional networks in pytorch framework', 22 | long_description=readme, 23 | license='MIT', 24 | 25 | # Package info 26 | packages=find_packages(exclude=('*test*',)), 27 | 28 | # 29 | zip_safe=True, 30 | install_requires=requirements, 31 | 32 | # Classifiers 33 | classifiers=[ 34 | 'Programming Language :: Python :: 3', 35 | ], 36 | ) 37 | --------------------------------------------------------------------------------