├── .github └── workflows │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── pthflops ├── __init__.py ├── ops_fx.py ├── ops_jit.py └── utils.py ├── setup.py ├── test ├── smoke_test.py └── test_ops.py └── tox.ini /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test Pytorch Flops Counter 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [3.8] 11 | pytorch-version: [1.5.1, 1.6.0, 1.7.1, 1.8.0, nightly] 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Add conda to system path 19 | run: | 20 | echo $CONDA/bin >> $GITHUB_PATH 21 | - name: Install dependencies 22 | run: | 23 | conda create -q -n test-environment python=${{ matrix.python-version }} 24 | source activate test-environment 25 | pip install Pillow==6.1 26 | if [[ "${{ matrix.pytorch-version }}" == "nightly" ]]; then 27 | conda install pytorch torchvision cpuonly -c pytorch-nightly 28 | else 29 | conda install pytorch==${{ matrix.pytorch-version }} torchvision cpuonly -c pytorch 30 | fi 31 | conda install flake8 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | python setup.py install 34 | - name: Test with pytest 35 | run: | 36 | source activate test-environment 37 | conda install pytest 38 | pytest test/ 39 | - name : Lint with flake8 40 | run: | 41 | source activate test-environment 42 | flake8 . 
--exit-zero -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | #vscode 107 | .vscode/ 108 | .vs/ 109 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019-2021, Adrian Bulat 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![Test Pytorch Flops Counter](https://github.com/1adrianb/pytorch-estimate-flops/workflows/Test%20Pytorch%20Flops%20Counter/badge.svg)](https://github.com/1adrianb/pytorch-estimate-flops/actions) 2 | [![PyPI](https://img.shields.io/pypi/v/pthflops.svg?style=flat)](https://pypi.org/project/pthflops/) 3 | 4 | # pytorch-estimate-flops 5 | 6 | Simple PyTorch utility that estimates the number of FLOPs for a given network. For now only some basic operations are supported (basically the ones I needed for my models). More will be added soon. 7 | 8 | All contributions are welcome. 9 | 10 | ## Installation 11 | 12 | You can install the package using pip: 13 | 14 | ```bash 15 | pip install pthflops 16 | ``` 17 | or directly from the GitHub repository: 18 | ```bash 19 | git clone https://github.com/1adrianb/pytorch-estimate-flops && cd pytorch-estimate-flops 20 | python setup.py install 21 | ``` 22 | 23 | Note: PyTorch 1.8 or newer is recommended (older versions fall back to the deprecated JIT-based counter). 24 | 25 | ## Example 26 | 27 | ```python 28 | import torch 29 | from torchvision.models import resnet18 30 | 31 | from pthflops import count_ops 32 | 33 | # Create a network and a corresponding input 34 | device = 'cuda:0' 35 | model = resnet18().to(device) 36 | inp = torch.rand(1,3,224,224).to(device) 37 | 38 | # Count the number of FLOPs 39 | count_ops(model, inp) 40 | ``` 41 | 42 | Ignoring certain layers: 43 | 44 | ```python 45 | import torch 46 | from torch import nn 47 | from pthflops import count_ops 48 | 49 | class CustomLayer(nn.Module): 50 | def __init__(self): 51 | super(CustomLayer, self).__init__() 52 | self.conv1 = nn.Conv2d(5, 5, 1, 1, 0) 53 | # ...
# ---- /README.md (continued; the remainder of the "Ignoring certain layers" example) ----
#         # ... other layers present inside will also be ignored
#
#     def forward(self, x):
#         return self.conv1(x)
#
# # Create a network and a corresponding input
# inp = torch.rand(1,5,7,7)
# net = nn.Sequential(
#     nn.Conv2d(5, 5, 1, 1, 0),
#     nn.ReLU(inplace=True),
#     CustomLayer()
# )
#
# # Count the number of FLOPs, jit mode:
# count_ops(net, inp, mode='jit', ignore_layers=['CustomLayer'])
#
# # Note: if you are using PyTorch 1.8 or newer with fx instead of jit, the naming convention changed.
# # As such, you will have to pass ['_2_conv1'].
# # Please check your model definition to account for this.
# # Count the number of FLOPs, fx mode:
# count_ops(net, inp, ignore_layers=['_2_conv1'])
#
# ```
#
# ---- /pthflops/__init__.py ----
from .ops_jit import count_ops_jit

try:
    from .ops_fx import count_ops_fx
    force_jit = False
except (ImportError, AttributeError):
    # ops_fx needs ``torch.fx`` (ImportError when the module is missing) and
    # ``torch.fx.Interpreter`` (AttributeError when the module is incomplete);
    # on such PyTorch versions fall back to the JIT-based counter.
    force_jit = True
    print('Unable to import torch.fx, your pytorch version may be too old.')

__version__ = '0.4.2'

_SUPPORTED_MODES = ('fx', 'jit')


def count_ops(model, input, mode='fx', custom_ops=None, ignore_layers=None, print_readable=True, verbose=True,
              *args):
    r"""Estimates the number of FLOPs of a model, dispatching to the FX or JIT counter.

    :param model: the :class:`torch.nn.Module` to profile
    :param input: the input tensor fed to ``model``
    :param mode: ``'fx'`` (default) or ``'jit'``. When ``torch.fx`` could not be imported,
        ``'fx'`` falls back to ``'jit'``.
    :param custom_ops: :class:`dict` of custom counting callbacks forwarded to the selected
        counter (``None`` means no custom ops)
    :param ignore_layers: :class:`list` of layer names to skip (``None`` means ignore nothing)
    :param print_readable: if True, print the total number of FLOPs
    :param verbose: if True, print/collect every operation with a non-zero count
    :param args: extra positional arguments forwarded unchanged to the selected counter

    :return: whatever the selected counter returns, i.e. ``(total_ops, per_op_rows)``
    :raises ValueError: if ``mode`` is neither ``'fx'`` nor ``'jit'``
    """
    # Validate first so an unknown mode is rejected even when FX is unavailable.
    if mode not in _SUPPORTED_MODES:
        raise ValueError('Unknown mode selected.')

    # Fresh containers per call instead of shared mutable defaults.
    if custom_ops is None:
        custom_ops = {}
    if ignore_layers is None:
        ignore_layers = []

    if mode == 'fx' and not force_jit:
        counter = count_ops_fx
    else:
        if force_jit:
            print("FX is unsupported on your pytorch version, falling back to JIT")
        counter = count_ops_jit

    # Forward everything positionally: mixing keyword arguments with a trailing
    # ``*args`` would bind the first extra argument to ``custom_ops`` and raise
    # "got multiple values for argument".
    return counter(model, input, custom_ops, ignore_layers, print_readable, verbose, *args)
# ---- /pthflops/ops_fx.py ----
# FLOP counting based on torch.fx: the model is symbolically traced and every
# graph node is re-executed by ProfilingInterpreter, which records an op count
# for each module / function call.
from functools import reduce
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.fx
import torch.nn as nn

from .utils import same_device, print_table


def _prod(values: Any) -> int:
    r"""Return the product of the integers in ``values`` (``1`` for an empty iterable)."""
    return reduce(lambda x, y: x * y, values, 1)


def _count_convNd(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of a convolutional layer.

    Counts ``prod(kernel_size) * in_channels * out_channels / groups`` multiply-accumulates
    per output position (batch x spatial), plus ``out_channels`` additions per position
    when the layer has a bias.

    .. warning::
        Currently it ignores the padding.

    :param module: the executed :class:`torch.nn.modules.conv._ConvNd` module
    :param output: the tensor returned by ``module``
    :param args: positional arguments ``module`` was called with (unused)
    :param kwargs: keyword arguments ``module`` was called with (unused)

    :return: number of FLOPs
    :rtype: `int`
    """
    filters_per_channel = module.out_channels // module.groups
    conv_per_position_flops = _prod(module.kernel_size) * module.in_channels * filters_per_channel

    # Batch size times the number of spatial output positions.
    active_elements_count = output.shape[0] * _prod(output.shape[2:])

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_ops = 0
    if module.bias is not None:
        bias_ops = module.out_channels * active_elements_count

    return overall_conv_flops + bias_ops


def _count_relu(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of a ReLU activation.

    The comparison is counted as a FLOP too, giving two ops per output element.

    :param module: the executed :class:`torch.nn.ReLU` module (unused)
    :param output: the tensor returned by ``module``

    :return: number of FLOPs
    :rtype: `int`
    """
    return 2 * output.numel()  # also count the comparison


def _pool_kernel_numel(module: Any, output: torch.Tensor) -> int:
    r"""Return the number of elements covered by a pooling module's kernel.

    An ``int`` ``kernel_size`` is expanded to every spatial dimension of ``output``
    (all dimensions after the first two).
    """
    kernel_size = module.kernel_size
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size] * (output.dim() - 2)
    return _prod(kernel_size)


def _count_avgpool(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of an Average Pooling layer.

    Each output element costs ``kernel_numel - 1`` additions and one division.

    :param module: the executed :class:`torch.nn.modules.pooling._AvgPoolNd` module
    :param output: the tensor returned by ``module``

    :return: number of FLOPs
    :rtype: `int`
    """
    ops_add = _pool_kernel_numel(module, output) - 1
    ops_div = 1
    return (ops_add + ops_div) * output.numel()


# Number of trailing spatial dimensions for each adaptive average pooling type.
_ADAPTIVE_AVGPOOL_SPATIAL_DIMS = (
    (nn.AdaptiveAvgPool1d, 1),
    (nn.AdaptiveAvgPool2d, 2),
    (nn.AdaptiveAvgPool3d, 3),
)


def _count_globalavgpool(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of an Adaptive (e.g. global) Average Pooling layer.

    Each output element averages a window of ``ceil(in_size / out_size)`` input elements per
    spatial dimension (the whole spatial extent for a global pool), costing ``window - 1``
    additions and one division.

    :param module: the executed :class:`torch.nn.modules.pooling._AdaptiveAvgPoolNd` module
    :param output: the tensor returned by ``module``
    :param args: positional arguments ``module`` was called with; ``args[0]`` is the input

    :return: number of FLOPs
    :rtype: `int`
    """
    inp = args[0]

    spatial_dims = 2  # previous behaviour for any unrecognised subclass
    for pool_type, dims in _ADAPTIVE_AVGPOOL_SPATIAL_DIMS:
        if isinstance(module, pool_type):
            spatial_dims = dims
            break

    # Ceil division; max(..., 1) guards against degenerate zero-sized outputs.
    window = _prod(
        -(-in_size // max(out_size, 1))
        for in_size, out_size in zip(inp.shape[-spatial_dims:], output.shape[-spatial_dims:]))

    ops_add = window - 1
    ops_div = 1
    return (ops_add + ops_div) * output.numel()


def _count_maxpool(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of a Max Pooling layer.

    Each output element costs ``kernel_numel - 1`` comparisons.

    :param module: the executed :class:`torch.nn.modules.pooling._MaxPoolNd` module
    :param output: the tensor returned by ``module``

    :return: number of FLOPs
    :rtype: `int`
    """
    return (_pool_kernel_numel(module, output) - 1) * output.numel()


def _count_bn(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of a Batch Normalisation operation (two ops per output element).

    :param module: the executed :class:`torch.nn.modules.batchnorm._BatchNorm` module (unused)
    :param output: the tensor returned by ``module``

    :return: number of FLOPs
    :rtype: `int`
    """
    return output.numel() * 2


def _count_linear(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of a GEMM / linear layer (also used for ``matmul``).

    Counts ``input.numel() * out_features`` multiply-accumulates, plus one addition per
    output element when ``module`` is an :class:`torch.nn.Linear` with a bias.

    :param module: the executed :class:`torch.nn.Linear` module, or the called function
    :param output: the tensor returned by the call
    :param args: positional arguments of the call; ``args[0]`` is the (left) input

    :return: number of FLOPs
    :rtype: `int`
    """
    bias_ops = 0
    if isinstance(module, nn.Linear) and module.bias is not None:
        # The bias is added to every output element, for every leading (batch) index.
        bias_ops = output.numel()
    return args[0].numel() * output.shape[-1] + bias_ops


def _count_add_mul(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Estimates the number of FLOPs of an element-wise addition / multiplication.

    :param output: the tensor returned by the call
    :param args: positional arguments of the call

    :return: ``output.numel()`` times the number of positional arguments
    :rtype: `int`
    """
    return output.numel() * len(args)


def _undefined_op(module: Any, output: torch.Tensor, args: Tuple[Any], kwargs: Dict[str, Any]) -> int:
    r"""Default case for undefined or free (in terms of FLOPs) operations.

    :return: always 0
    :rtype: `int`
    """
    return 0


# Module types and their counters, checked in order with isinstance().
_MODULE_COUNTERS = (
    (torch.nn.modules.conv._ConvNd, _count_convNd),
    (nn.ReLU, _count_relu),
    (torch.nn.modules.batchnorm._BatchNorm, _count_bn),
    (torch.nn.modules.pooling._MaxPoolNd, _count_maxpool),
    (torch.nn.modules.pooling._AvgPoolNd, _count_avgpool),
    (torch.nn.modules.pooling._AdaptiveAvgPoolNd, _count_globalavgpool),
    (torch.nn.Linear, _count_linear),
)

# Function names (``target.__name__`` of call_function nodes) and their counters.
_FUNCTION_COUNTERS = {
    'add': _count_add_mul,
    'mul': _count_add_mul,
    'matmul': _count_linear,
}


def count_operations(module: Any) -> Any:
    r"""Return the FLOP-counting callback for a module instance or a function name.

    :param module: a :class:`torch.nn.Module` instance, or the ``__name__`` of a called function
    :return: a callable ``(module, output, args, kwargs) -> int``; :func:`_undefined_op` if unknown
    """
    if isinstance(module, str):
        return _FUNCTION_COUNTERS.get(module, _undefined_op)
    for module_type, counter in _MODULE_COUNTERS:
        if isinstance(module, module_type):
            return counter
    return _undefined_op


class _ProfiledResult:
    r"""Carrier for the ``(output, ops, params)`` produced by the profiling hooks.

    A dedicated type is used instead of a plain tuple so that genuinely tuple-valued
    node results (``torch.Size`` from ``x.size()``, tuple-returning modules or methods,
    tuple model outputs) are never mistaken for profiling records.
    """
    __slots__ = ('output', 'ops', 'params')

    def __init__(self, output: Any, ops: int, params: int):
        self.output = output
        self.ops = ops
        self.params = params


class ProfilingInterpreter(torch.fx.Interpreter):
    r"""FX interpreter that records FLOPs and parameter counts for every module/function node.

    :param mod: the module to trace with :func:`torch.fx.symbolic_trace`
    :param custom_ops: optional :class:`dict` mapping module instances to counting callbacks
        ``(module, output, args, kwargs) -> int`` that override the built-in ones
    """

    def __init__(self, mod: torch.nn.Module, custom_ops: Optional[Dict[Any, Any]] = None):
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        self.custom_ops = {} if custom_ops is None else custom_ops

        self.flops: Dict[torch.fx.Node, float] = {}
        self.parameters: Dict[torch.fx.Node, float] = {}

    def run_node(self, n: torch.fx.Node) -> Any:
        # Unwrap profiling records returned by call_module / call_function and keep the
        # real output in the interpreter environment.
        return_val = super().run_node(n)
        if isinstance(return_val, _ProfiledResult):
            self.flops[n] = return_val.ops
            self.parameters[n] = return_val.params
            return_val = return_val.output

        return return_val

    def call_module(self, target: 'torch.fx.node.Target', args: Tuple[Any], kwargs: Dict[str, Any]) -> Any:
        # Execute the submodule, then count its ops and parameters.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        output = submod(*args, **kwargs)

        if submod in self.custom_ops:
            count_ops_funct = self.custom_ops[submod]
        else:
            count_ops_funct = count_operations(submod)
        current_ops = count_ops_funct(submod, output, args, kwargs)
        current_params = sum(p.numel() for p in submod.parameters())

        return _ProfiledResult(output, current_ops, current_params)

    def call_function(self, target: 'torch.fx.node.Target', args: Tuple[Any], kwargs: Dict[str, Any]) -> Any:
        assert not isinstance(target, str)

        # Execute the function and return the result
        output = target(*args, **kwargs)

        if isinstance(output, torch.Tensor):
            current_ops = count_operations(getattr(target, '__name__', None))(target, output, args, kwargs)
        else:
            # Python-level arithmetic on traced sizes (e.g. ``x.size(0) * 2``) yields plain
            # numbers: they have no ``numel()`` and cost no tensor FLOPs.
            current_ops = 0

        return _ProfiledResult(output, current_ops, 0)


def count_ops_fx(model: torch.nn.Module,
                 input: torch.Tensor,
                 custom_ops: Optional[Dict[Any, Any]] = None,
                 ignore_layers: Optional[List[str]] = None,
                 print_readable: bool = True,
                 verbose: bool = True,
                 *args):
    r"""Estimates the number of FLOPs of an :class:`torch.nn.Module` using torch.fx.

    :param model: the :class:`torch.nn.Module`
    :param input: a N-d :class:`torch.tensor` containing the input to the model
    :param custom_ops: :class:`dict` containing custom counting functions. The keys are module
        instances of ``model``, the values callables ``(module, output, args, kwargs) -> int``.
        This can override the ops present in the package. ``None`` means no custom ops.
    :param ignore_layers: :class:`list` containing the FX node names to be ignored
        (e.g. ``'_2_conv1'`` for ``model[2].conv1``). ``None`` means ignore nothing.
    :param print_readable: boolean, if True will print the number of FLOPs. default is True
    :param verbose: boolean, if True will print all the non-zero OPS operations from the network
    :param args: accepted for signature compatibility with the JIT counter; unused in FX mode

    :return: ``(total_ops, rows)`` where ``rows`` lists ``[node, ops]`` for non-zero nodes when
        ``verbose`` is True
    :rtype: `tuple`
    """
    custom_ops = {} if custom_ops is None else custom_ops
    ignore_layers = [] if ignore_layers is None else ignore_layers

    model, input = same_device(model, input)

    # Place the model in eval mode, required for some models. The previous mode is captured
    # once here and restored even if tracing fails.
    model_status = model.training
    model.eval()
    try:
        tracer = ProfilingInterpreter(model, custom_ops=custom_ops)
        tracer.run(input)
    finally:
        if model_status:
            model.train()

    ops = 0
    all_data = []

    for node, current_ops in tracer.flops.items():
        if any(node.name == ign_name for ign_name in ignore_layers):
            continue

        ops += current_ops

        if current_ops and verbose:
            all_data.append(['{}'.format(node), current_ops])

    if print_readable:
        if verbose:
            print_table(all_data)
        print("Input size: {0}".format(tuple(input.shape)))
        print("{:,} FLOPs or approx. {:,.2f} GFLOPs".format(ops, ops / 1e+9))

    return ops, all_data


# ---- /pthflops/ops_jit.py ----
# Legacy FLOP counting: the model is traced into an ONNX-style JIT graph and tensor
# shapes are parsed out of the textual representation of each graph node.
import re
from functools import reduce
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
import torch

from .utils import print_table, scope_name_workaround, same_device, deprecated


def _torch_version() -> Tuple[int, ...]:
    r"""Return the numeric ``(major, minor, patch)`` of ``torch.__version__``.

    Replaces ``distutils.version.LooseVersion`` (distutils was removed in Python 3.12).
    Local-build and pre-release suffixes such as ``+cpu``, ``a0`` or ``.dev2021...``
    are ignored, so ``'1.8.0a0+abc'`` compares equal to ``(1, 8, 0)``.
    """
    parts = []
    for piece in torch.__version__.split('+')[0].split('.')[:3]:
        match = re.match(r'\d+', piece)
        if match is None:
            break
        parts.append(int(match.group(0)))
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def string_to_shape(node_string: Any, bias: bool = False):
    r"""Extract the shape of a given tensor from an onnx string

    :param node_string: a :class:`str` or the node from which the shape will be extracted
    :param bias: boolean, if True will return the shape of the bias. If no bias is found the function will return None

    :return: a tuple containing the shape of the tensor, or None if no shape was found
    :rtype: :class:`tuple`
    """
    if not isinstance(node_string, str):
        node_string = str(node_string)
    node_string = node_string.replace('!', '')
    if bias:
        m = re.search(r"Float\((\d+)\)", node_string)
    else:
        m = re.search(r"Float\(([\d\s\,]+)\)", node_string)
    if m is None:
        return None
    return tuple(int(x) for x in m.groups()[0].split(','))


def _parse_node_inputs(node: Any, version: int = 2):
    r"""Parse the shapes of a JIT graph node's inputs from its textual representation.

    When an input's shape cannot be found, parsing restarts from the node producing the
    first input.

    :param node: a JIT graph node
    :param version: graph text format (2: ``size:stride`` pairs, 3: sizes with ``strides=[...]``)

    :return: ``(inputs, inputs_names)`` where ``inputs`` maps input names to shape lists
    """
    inputs = {}
    inputs_names = []
    for inp in node.inputs():
        inp = str(inp)
        curr_node_name = re.search(r'(.*) defined in ', inp).group(1)
        # Escape the name: graph value names such as "input.1" contain regex metacharacters.
        escaped_name = re.escape(curr_node_name)
        if version <= 2:
            extracted_data = re.search(r'%' + escaped_name + r' : Float\(([^%]*)\)[,| ]', inp)
        elif version == 3:
            extracted_data = re.search(r'%' + escaped_name + r' : Float\(((\d+, )+)', inp)

        if extracted_data is not None:
            extracted_data = extracted_data.group(1)
        else:
            return _parse_node_inputs(
                list(node.inputs())[0].node(),
                version
            )
        if version <= 2:
            inputs[curr_node_name] = re.findall(r'(\d+):', extracted_data)
        elif version == 3:
            inputs[curr_node_name] = re.findall(r'(\d+)', extracted_data)
        inputs[curr_node_name] = list(map(int, inputs[curr_node_name]))
        inputs_names.append(curr_node_name)
    return inputs, inputs_names


def parse_node_info(node: Any, version: int = 2):
    r"""Parse a JIT graph node into ``(node_name, inputs, inputs_names, out_size)``.

    ``out_size`` is read from the first stride of the output (the element count per batch
    entry) and is 0 when it cannot be parsed.
    """
    inputs, inputs_names = _parse_node_inputs(node, version=version)
    node = str(node)
    node_name = re.search(r'%(.*) : ', node).group(1)
    out_size = 0
    if version == 2:
        out_size = re.search(r'Float\(\d+:(\d+),', node)
    elif version == 3:
        out_size = re.search(r'strides=\[(\d+),', node)

    if out_size:
        out_size = out_size.group(1)
    else:
        out_size = 0
        print("Unable to parse output size of node: {0}, then the output size is 0".format(node))

    return node_name, inputs, inputs_names, int(out_size)


def _count_convNd(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs in conv layer

    .. warning::
        Currently it ignores the padding

    :param node: an onnx node defining a convolutional layer
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    kernel_size = node['kernel_shape']
    num_groups = node['group']

    if version == 1:
        node_inputs = list(node.inputs())
        inp = string_to_shape(node_inputs[0])
        out = string_to_shape(list(node.outputs())[0])
        out_ops = reduce(lambda x, y: x * y, out)
        # ONNX Conv inputs are (X, W[, B]): a third input means a bias is present
        # (same rule as the version 2/3 branch below).
        bias_ops = 1 if len(node_inputs) == 3 else 0
        f_in = inp[1]
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
        f_in = inputs[inputs_names[0]][1]
        bias_ops = 1 if len(inputs_names) == 3 else 0

    kernel_ops = f_in
    for ks in kernel_size:
        kernel_ops *= ks

    kernel_ops = kernel_ops // num_groups
    combined_ops = kernel_ops + bias_ops

    total_ops = combined_ops * out_ops

    return total_ops


def _count_relu(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs of a ReLU activation.
    The function will count the comparison operation as a FLOP.

    :param node: an onnx node defining a ReLU op
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    if version == 1:
        inp = string_to_shape(list(node.inputs())[0])
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
        inp = inputs[inputs_names[0]]
    total_ops = 2 * reduce(lambda x, y: x * y, inp)  # also count the comparison
    return total_ops


def _count_avgpool(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs of an Average Pooling layer.

    :param node: an onnx node defining an average pooling layer
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    if version == 1:
        out = string_to_shape(list(node.outputs())[0])
        out_ops = reduce(lambda x, y: x * y, out)
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)

    ops_add = reduce(lambda x, y: x * y, node['kernel_shape']) - 1
    ops_div = 1
    total_ops = (ops_add + ops_div) * out_ops
    return total_ops


def _count_globalavgpool(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs of a Global Average Pooling layer.

    :param node: an onnx node defining a global average pooling layer
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    if version == 1:
        inp = string_to_shape(list(node.inputs())[0])
        out = string_to_shape(list(node.outputs())[0])
        out_ops = reduce(lambda x, y: x * y, out)
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
        inp = inputs[inputs_names[0]]

    ops_add = reduce(lambda x, y: x * y, [inp[-2], inp[-1]]) - 1
    ops_div = 1
    total_ops = (ops_add + ops_div) * out_ops
    return total_ops


def _count_maxpool(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs of a Max Pooling layer.

    :param node: an onnx node defining a max pooling layer
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    if version == 1:
        out = string_to_shape(list(node.outputs())[0])
        out_ops = reduce(lambda x, y: x * y, out)
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)

    ops_add = reduce(lambda x, y: x * y, node['kernel_shape']) - 1
    total_ops = ops_add * out_ops
    return total_ops


def _count_bn(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs of a Batch Normalisation operation.

    :param node: an onnx node defining a batch norm op
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    if version == 1:
        if 'BatchNorm1d' in node.scopeName():
            inp = string_to_shape(list(node.inputs())[1])
        else:
            inp = string_to_shape(list(node.inputs())[0])
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
        inp = inputs[inputs_names[0]]

    total_ops = reduce(lambda x, y: x * y, inp) * 2
    return total_ops


def _count_linear(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs of a GEMM or linear layer.

    :param node: an onnx node defining a GEMM or linear layer
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    if version == 1:
        inp = string_to_shape(list(node.inputs())[0])
        out = string_to_shape(list(node.outputs())[0])
        f_in = inp[1]
        out_ops = reduce(lambda x, y: x * y, out)
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
        f_in = inputs[inputs_names[0]][1]

    total_ops = f_in * out_ops
    return total_ops


def _count_add_mul(node: Any, version: int = 2):
    r"""Estimates the number of FLOPs of an element-wise summation / multiplication op.

    :param node: an onnx node defining the op
    :param version: graph text format (see :func:`parse_node_info`)

    :return: number of FLOPs
    :rtype: `int`
    """
    if version == 1:
        inp = string_to_shape(list(node.inputs())[0])
    elif version in [2, 3]:
        node_name, inputs, inputs_names, out_ops = parse_node_info(node, version=version)
        inp = inputs[inputs_names[0]]
    return reduce(lambda x, y: x * y, inp)


def _undefined_op(node: Any, version: int = 2):
    r"""Default case for undefined or free (in terms of FLOPs) operations

    :param node: an onnx node

    :return: always 0
    :rtype: `int`
    """
    return 0


# Maps ONNX op kinds to their counters; unknown kinds count as 0 FLOPs.
count_operations = defaultdict(
    lambda: _undefined_op,
    {
        'onnx::Conv': _count_convNd,
        'onnx::Relu': _count_relu,
        'onnx::AveragePool': _count_avgpool,
        'onnx::MaxPool': _count_maxpool,
        'onnx::BatchNormalization': _count_bn,
        'onnx::Gemm': _count_linear,
        'onnx::MatMul': _count_linear,
        'onnx::Add': _count_add_mul,
        'onnx::Mul': _count_add_mul,
        'onnx::GlobalAveragePool': _count_globalavgpool
    }
)


@deprecated("JIT mode is deprecated, please update to pytorch 1.8.0 or newer and use FX.")
def count_ops_jit(model: torch.nn.Module,
                  input: torch.Tensor,
                  custom_ops: Optional[Dict[Any, Any]] = None,
                  ignore_layers: Optional[List[str]] = None,
                  print_readable: bool = True,
                  verbose: bool = True,
                  *args):
    r"""Estimates the number of FLOPs of an :class:`torch.nn.Module`

    :param model: the :class:`torch.nn.Module`
    :param input: a N-d :class:`torch.tensor` containing the input to the model
    :param custom_ops: :class:`dict` containing custom counting functions. The keys represent the name
        of the targeted onnx op (e.g. ``'onnx::Conv'``), while the value a lambda or callback
        ``(node, version) -> int`` returning the number of ops.
        This can override the ops present in the package. ``None`` means no custom ops.
    :param ignore_layers: :class:`list` of names; nodes whose scope name contains any of them are ignored.
        ``None`` means ignore nothing.
    :param print_readable: boolean, if True will print the number of FLOPs. default is True
    :param verbose: boolean, if True will print all the non-zero OPS operations from the network
    :param args: extra arguments forwarded to the JIT tracing call

    :return: ``(total_ops, rows)`` where ``rows`` lists ``[scope/kind, ops]`` for non-zero nodes when
        ``verbose`` is True
    :rtype: `tuple`
    """
    custom_ops = {} if custom_ops is None else custom_ops
    ignore_layers = [] if ignore_layers is None else ignore_layers

    model, input = same_device(model, input)

    # Place the model in eval mode, required for some models; restore it even if tracing fails.
    model_status = model.training
    model.eval()

    # Convert pytorch module to ONNX
    version = 1
    torch_version = _torch_version()
    try:
        if torch_version > (1, 3, 1):
            with scope_name_workaround():
                trace, _ = torch.jit._get_trace_graph(model, input, *args)
                graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)

            # The textual node format (and thus the parsing rules) changed across releases.
            if (1, 6, 0) <= torch_version < (1, 8, 0):
                version = 2
            elif torch_version >= (1, 8, 0):
                version = 3
        else:
            # PyTorch 1.3 and below
            trace, _ = torch.jit.get_trace_graph(model, input, *args)
            torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
            graph = trace.graph()
    finally:
        if model_status:
            model.train()

    ops = 0
    all_data = []
    for node in graph.nodes():
        if any(name in node.scopeName() for name in ignore_layers):
            continue
        if node.kind() in custom_ops:
            current_ops = custom_ops[node.kind()](node, version)
        else:
            current_ops = count_operations[node.kind()](node, version)
        ops += current_ops

        if current_ops and verbose:
            all_data.append(['{}/{}'.format(node.scopeName(), node.kind()), current_ops])

    if print_readable:
        if verbose:
            print_table(all_data)
        print("Input size: {0}".format(tuple(input.shape)))
        print("{:,} FLOPs or approx. {:,.2f} GFLOPs".format(ops, ops / 1e+9))

    return ops, all_data


# ---- /pthflops/utils.py (start of file; the excerpt is truncated here) ----
# 1 | import functools
# 2 | import inspect
# 3 | import warnings
# 4 | from typing import Iterable
# 5 |
# 6 | import torch
# 7 |
# 8 |
# 9 | def print_table(rows, header=['Operation', 'OPS']):
# 10 |     r"""Simple helper function to print a list of lists as a table
# 11 |
# 12 |     :param rows: a :class:`list` of :class:`list` containing the data to be printed.
Each entry in the list 13 | represents an individual row 14 | :param input: (optional) a :class:`list` containing the header of the table 15 | """ 16 | if len(rows) == 0: 17 | return 18 | col_max = [max([len(str(val[i])) for val in rows]) + 3 for i in range(len(rows[0]))] 19 | row_format = ''.join(["{:<" + str(length) + "}" for length in col_max]) 20 | 21 | if len(header) > 0: 22 | print(row_format.format(*header)) 23 | print(row_format.format(*['-' * (val - 2) for val in col_max])) 24 | 25 | for row in rows: 26 | print(row_format.format(*row)) 27 | print(row_format.format(*['-' * (val - 3) for val in col_max])) 28 | 29 | 30 | def same_device(model, input): 31 | # Remove dataparallel wrapper if present 32 | if isinstance(model, torch.nn.DataParallel): 33 | model = model.module 34 | 35 | # Make sure that the input is on the same device as the model 36 | if len(list(model.parameters())): 37 | input_device = input.device if not isinstance(input, Iterable) else input[0].device 38 | if next(model.parameters()).device != input_device: 39 | if isinstance(input, Iterable): 40 | for inp in input: 41 | inp.to(next(model.parameters()).device) 42 | else: 43 | input.to(next(model.parameters()).device) 44 | 45 | return model, input 46 | 47 | # Workaround for scopename in pytorch 1.4 and newer 48 | # see: https://github.com/pytorch/pytorch/issues/33463 49 | 50 | 51 | class scope_name_workaround(object): 52 | def __init__(self): 53 | self.backup = None 54 | 55 | def __enter__(self): 56 | def _tracing_name(self_, tracing_state): 57 | if not tracing_state._traced_module_stack: 58 | return None 59 | module = tracing_state._traced_module_stack[-1] 60 | for name, child in module.named_children(): 61 | if child is self_: 62 | return name 63 | return None 64 | 65 | def _slow_forward(self_, *input, **kwargs): 66 | tracing_state = torch._C._get_tracing_state() 67 | if not tracing_state or isinstance(self_.forward, torch._C.ScriptMethod): 68 | return self_.forward(*input, **kwargs) 69 | if 
not hasattr(tracing_state, '_traced_module_stack'): 70 | tracing_state._traced_module_stack = [] 71 | name = _tracing_name(self_, tracing_state) 72 | if name: 73 | tracing_state.push_scope('%s[%s]' % (self_._get_name(), name)) 74 | else: 75 | tracing_state.push_scope(self_._get_name()) 76 | tracing_state._traced_module_stack.append(self_) 77 | try: 78 | result = self_.forward(*input, **kwargs) 79 | finally: 80 | tracing_state.pop_scope() 81 | tracing_state._traced_module_stack.pop() 82 | return result 83 | 84 | self.backup = torch.nn.Module._slow_forward 85 | setattr(torch.nn.Module, '_slow_forward', _slow_forward) 86 | 87 | def __exit__(self, type, value, tb): 88 | setattr(torch.nn.Module, '_slow_forward', self.backup) 89 | 90 | # Source: https://stackoverflow.com/questions/2536307/decorators-in-the-python-standard-lib-deprecated-specifically 91 | string_types = (type(b''), type(u'')) 92 | 93 | 94 | def deprecated(reason): 95 | """ 96 | This is a decorator which can be used to mark functions 97 | as deprecated. It will result in a warning being emitted 98 | when the function is used. 99 | """ 100 | if isinstance(reason, string_types): 101 | 102 | # The @deprecated is used with a 'reason'. 103 | # 104 | # .. code-block:: python 105 | # 106 | # @deprecated("please, use another function") 107 | # def old_function(x, y): 108 | # pass 109 | 110 | def decorator(func1): 111 | 112 | if inspect.isclass(func1): 113 | fmt1 = "Call to deprecated class {name} ({reason})." 114 | else: 115 | fmt1 = "Call to deprecated function {name} ({reason})." 
116 | 117 | @functools.wraps(func1) 118 | def new_func1(*args, **kwargs): 119 | warnings.simplefilter('always', DeprecationWarning) 120 | warnings.warn( 121 | fmt1.format(name=func1.__name__, reason=reason), 122 | category=DeprecationWarning, 123 | stacklevel=2 124 | ) 125 | warnings.simplefilter('default', DeprecationWarning) 126 | return func1(*args, **kwargs) 127 | 128 | return new_func1 129 | 130 | return decorator 131 | 132 | elif inspect.isclass(reason) or inspect.isfunction(reason): 133 | # The @deprecated is used without any 'reason'. 134 | # 135 | # .. code-block:: python 136 | # 137 | # @deprecated 138 | # def old_function(x, y): 139 | # pass 140 | 141 | func2 = reason 142 | 143 | if inspect.isclass(func2): 144 | fmt2 = "Call to deprecated class {name}." 145 | else: 146 | fmt2 = "Call to deprecated function {name}." 147 | 148 | @functools.wraps(func2) 149 | def new_func2(*args, **kwargs): 150 | warnings.simplefilter('always', DeprecationWarning) 151 | warnings.warn( 152 | fmt2.format(name=func2.__name__), 153 | category=DeprecationWarning, 154 | stacklevel=2 155 | ) 156 | warnings.simplefilter('default', DeprecationWarning) 157 | return func2(*args, **kwargs) 158 | 159 | return new_func2 160 | 161 | else: 162 | raise TypeError(repr(type(reason))) 163 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | from os import path 4 | import re 5 | from setuptools import setup, find_packages 6 | # To use consisten encodings 7 | from codecs import open 8 | 9 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py 10 | 11 | 12 | def read(*names, **kwargs): 13 | with io.open( 14 | os.path.join(os.path.dirname(__file__), *names), 15 | encoding=kwargs.get("encoding", "utf8") 16 | ) as fp: 17 | return fp.read() 18 | 19 | # Function from: https://github.com/pytorch/vision/blob/master/setup.py 20 | 21 | 
22 | def find_version(*file_paths): 23 | version_file = read(*file_paths) 24 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", 25 | version_file, re.M) 26 | if version_match: 27 | return version_match.group(1) 28 | raise RuntimeError("Unable to find version string.") 29 | 30 | here = path.abspath(path.dirname(__file__)) 31 | 32 | # Get the long description from the README file 33 | with open(path.join(here, 'README.md'), encoding='utf-8') as readme_file: 34 | long_description = readme_file.read() 35 | 36 | VERSION = find_version('pthflops', '__init__.py') 37 | 38 | requirements = [ 39 | 'torch' 40 | ] 41 | 42 | setup( 43 | name='pthflops', 44 | version=VERSION, 45 | 46 | description="Estimate FLOPs of neural networks", 47 | long_description=long_description, 48 | long_description_content_type="text/markdown", 49 | 50 | # Author details 51 | author="Adrian Bulat", 52 | author_email="adrian@adrianbulat.com", 53 | url="https://github.com/1adrianb/pytorch-estimate-flops", 54 | 55 | # Package info 56 | packages=find_packages(exclude=('test',)), 57 | 58 | install_requires=requirements, 59 | license='BSD', 60 | zip_safe=True, 61 | 62 | classifiers=[ 63 | 'Development Status :: 4 - Beta', 64 | 'Intended Audience :: Developers', 65 | 'Operating System :: OS Independent', 66 | 'License :: OSI Approved :: BSD License', 67 | 'Natural Language :: English', 68 | 69 | # Supported python versions 70 | 'Programming Language :: Python :: 3', 71 | 'Programming Language :: Python :: 3.5', 72 | 'Programming Language :: Python :: 3.6', 73 | 'Programming Language :: Python :: 3.7', 74 | 'Programming Language :: Python :: 3.8', 75 | 'Programming Language :: Python :: 3.9', 76 | ], 77 | ) 78 | -------------------------------------------------------------------------------- /test/smoke_test.py: -------------------------------------------------------------------------------- 1 | import pthflops 2 | 
-------------------------------------------------------------------------------- /test/test_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | import pytest 4 | from pthflops import count_ops 5 | from torchvision.models import resnet18 6 | 7 | # TODO: Add test for every op 8 | 9 | 10 | class Tester(unittest.TestCase): 11 | 12 | def test_overall(self): 13 | input = torch.rand(1, 3, 224, 224) 14 | net = resnet18() 15 | estimated, estimations_dict = count_ops(net, input, print_readable=False, verbose=False) 16 | expected = 1826843136 17 | assert expected == pytest.approx(estimated, 1000000) 18 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E305,E402,E721,E722,F401,F403,F405,F821,F841,F999 --------------------------------------------------------------------------------