├── .flake8
├── .github
│   └── workflows
│       └── main.yaml
├── .gitignore
├── .isort.cfg
├── CHANGELOG.md
├── LICENSE
├── README.md
├── ptflops
│   ├── __init__.py
│   ├── aten_engine.py
│   ├── aten_ops.py
│   ├── flops_counter.py
│   ├── pytorch_engine.py
│   ├── pytorch_ops.py
│   └── utils.py
├── pyproject.toml
├── samples
│   ├── bert.py
│   └── classification.py
└── tests
    └── common_test.py

--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | application_import_names = ptflops
3 | import-order-style = pep8
4 | max-line-length = 90
5 | per-file-ignores = __init__.py:F401
6 |
--------------------------------------------------------------------------------
/.github/workflows/main.yaml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on:
4 |   pull_request:
5 |   push:
6 |     branches: [main]
7 |
8 | jobs:
9 |   build:
10 |
11 |     runs-on: ubuntu-latest
12 |     strategy:
13 |       matrix:
14 |         python-version: [3.9]
15 |
16 |     steps:
17 |     - name: Checkout repository and submodules
18 |       uses: actions/checkout@v2
19 |       with:
20 |         submodules: recursive
21 |     - name: Set up Python ${{ matrix.python-version }}
22 |       uses: actions/setup-python@v2
23 |       with:
24 |         python-version: ${{ matrix.python-version }}
25 |     - name: Cache pip
26 |       uses: actions/cache@v2
27 |       with:
28 |         # This path is specific to Ubuntu
29 |         path: ~/.cache/pip
30 |         # Look to see if there is a cache hit for the corresponding requirements file
31 |         key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
32 |           ${{ runner.os }}-pip-${{ hashFiles('test_requirements.txt') }}
33 |         restore-keys: |
34 |           ${{ runner.os }}-pip-
35 |           ${{ runner.os }}-
36 |     - name: Install dependencies
37 |       run: |
38 |         python -m pip install --upgrade pip
39 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
40 |     - name: Install ptflops
41 |       run: |
42 |         pip install .[dev]
43 |     - name: Testing with pytest
44 |       run: |
45 |         python -m pytest . -s
46 |     - name: Linting with flake8
47 |       run: |
48 |         python -m flake8 .
49 |         python -m isort -rc --check-only --diff ./ptflops ./tests
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | .history/
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line_length = 79 3 | multi_line_output = 0 4 | known_standard_library = setuptools 5 | known_first_party = ptflops 6 | known_third_party = PIL,asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnx,pycocotools,pytest,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision 7 | no_lines_before = STDLIB,LOCALFOLDER 8 | default_section = THIRDPARTY 9 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # ptflops versions log 2 | 3 | ## v 0.7.4 4 | - Fix hook for nn.functional.interpolate. 5 | - Add ignore and custom modules for aten. 6 | - Add an option to disable counting of functional-style operations in pytorch backend. 7 | 8 | ## v 0.7.3 9 | - Add aten backend to collect the amount of flops on aten level. 10 | 11 | ## v 0.7.2.2 12 | - Switch from setup.py to pyproject 13 | 14 | ## v 0.7.2 15 | - Add type annotations and doc strings to the main API. 16 | - Add support of HuggingFace/Timm VIT transformers. 17 | - Update torchvision benchmark in docs. 18 | 19 | ## v 0.7.1.2 20 | - Fix failure when using input constructor. 21 | 22 | ## v 0.7.1 23 | - Experimental support of torchvision.ops.DeformConv2d 24 | - Experimental support of torch.functional.* and tensor.* operators 25 | 26 | ## v 0.7 27 | - Add ConvNext to sample, fix wrong torchvision compatibility requirement. 28 | - Support LayerNorm. 29 | 30 | ## v 0.6.9 31 | - Fix unnecessary warnings. 32 | - Improve per layer statistics output. 33 | 34 | ## v 0.6.8 35 | - Add support of GELU activation. 36 | - Fix per layer statistic output in case of zero parameters number. 37 | - Cleanup flops and params attrs after ptflops has finished counting. 38 | 39 | ## v 0.6.7 40 | - Add batch_first flag support in MultiheadAttention hook 41 | 42 | ## v 0.6.6 43 | - Add hooks for Instance and Group norms. 44 | 45 | ## v 0.6.5 46 | - Add a hook for MultiheadAttention. 47 | 48 | ## v 0.6.4 49 | - Fix unaccounted bias flops in Linear. 50 | - Fix hook for ConvTranspose*d. 51 | 52 | ## v 0.6.3 53 | - Implicitly use repr to print a model with extra_repr. 
54 |
55 | ## v 0.6.2
56 | - Fix integer overflow on Windows.
57 | - Check if the input object is inherited from nn.Module.
58 |
59 | ## v 0.6.1
60 | - Add experimental version of hooks for recurrent layers (RNN, GRU, LSTM).
61 |
62 | ## v 0.6
63 | - Add verbose option to log layers that are not supported by ptflops.
64 | - Add an option to filter a list of operations from the final result.
65 |
66 | ## v 0.5.2
67 | - Fix handling of intermediate dimensions in the Linear layer hook.
68 |
69 | ## v 0.5
70 | - Add per-sequential estimation of the number of parameters.
71 | - Fix the sample not working without a GPU.
72 | - Clarify the output of the sample.
73 |
74 | ## v 0.4
75 | - Allocate temporary blobs on the same device as the model's parameters.
76 |
77 | ## v 0.3
78 | - Add 1d operators: batch norm, poolings, convolution.
79 | - Add the ability to output an extended report to any output stream.
80 |
81 | ## v 0.2
82 | - Add new operations: Conv3d, BatchNorm3d, MaxPool3d, AvgPool3d, ConvTranspose2d.
83 | - Add some results on widespread models to the README.
84 | - Minor bugfixes.
85 |
86 | ## v 0.1
87 | - Initial release with basic functionality.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Vladislav Sovrasov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Flops counting tool for neural networks in pytorch framework
2 | [![Pypi version](https://img.shields.io/pypi/v/ptflops.svg)](https://pypi.org/project/ptflops/)
3 |
4 | This tool is designed to compute the theoretical amount of multiply-add operations
5 | in neural networks. It can also compute the number of parameters and
6 | print the per-layer computational cost of a given network.
7 |
8 | `ptflops` has two backends, `pytorch` and `aten`. The `pytorch` backend is a legacy one: it considers `nn.Module`s only. However,
9 | it is still useful, since it provides better per-layer analytics for CNNs. In all other cases it is recommended to use the
10 | `aten` backend, which operates on aten-level operations and therefore covers more model architectures (including transformers).
11 | Note that `get_model_complexity_info` currently defaults to the `pytorch` backend, so pass `backend='aten'` explicitly when needed. Please don't use the `pytorch` backend for transformer architectures.
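A minimal sketch of switching between the two backends through the common entry point (the single-layer model below is just a stand-in for any `nn.Module`):

```python
import torch.nn as nn

from ptflops import FLOPS_BACKEND, get_model_complexity_info

model = nn.Sequential(nn.Linear(128, 64))

# legacy nn.Module-level counting
macs, params = get_model_complexity_info(model, (128,), backend=FLOPS_BACKEND.PYTORCH,
                                         print_per_layer_stat=False)

# aten-level counting, recommended for transformer architectures
macs, params = get_model_complexity_info(model, (128,), backend=FLOPS_BACKEND.ATEN,
                                         print_per_layer_stat=False)
```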
12 |
13 | ## `aten` backend
14 | ### Operations considered:
15 | - aten.mm, aten.matmul, aten.addmm, aten.bmm
16 | - aten.convolution
17 |
18 | ### Usage tips
19 | - Use `verbose=True` to see the operations that were not considered during the complexity computation.
20 | - This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
21 |   Modules at the second and deeper levels of nesting are not shown in the per-layer statistics.
22 | - The `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
23 |   for research purposes. For instance, one can drop all convolutions from the counting process by
24 |   specifying `ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution]`.
25 |
26 | ## `pytorch` backend
27 | ### Supported layers:
28 | - Conv1d/2d/3d (including grouping)
29 | - ConvTranspose1d/2d/3d (including grouping)
30 | - BatchNorm1d/2d/3d, GroupNorm, InstanceNorm1d/2d/3d, LayerNorm
31 | - Activations (ReLU, PReLU, ELU, ReLU6, LeakyReLU, GELU)
32 | - Linear
33 | - Upsample
34 | - Poolings (AvgPool1d/2d/3d, MaxPool1d/2d/3d and adaptive ones)
35 |
36 | Experimental support:
37 | - RNN, LSTM, GRU (NLH layout is assumed)
38 | - RNNCell, LSTMCell, GRUCell
39 | - torch.nn.MultiheadAttention
40 | - torchvision.ops.DeformConv2d
41 | - visual transformers from [timm](https://github.com/huggingface/pytorch-image-models)
42 |
43 | ### Usage tips
44 |
45 | - This backend doesn't take some of the `torch.nn.functional.*` and `tensor.*` operations into account, so unsupported operations
46 |   do not contribute to the final complexity estimation. See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check the supported ops.
47 |   Sometimes functional-level hooks conflict with the hooks for `nn.Module` (for instance, custom ones). In that case, counting these ops can be disabled by
48 |   passing `backend_specific_config={"count_functional" : False}`.
49 | - `ptflops` launches a given model on a random tensor and estimates the amount of computations during inference. Complicated models can have several inputs, some of which may be optional. To construct a non-trivial input one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with the named input arguments of the model. This dict is then passed to the model as keyword arguments (a minimal sketch is given after the Requirements section below).
50 | - The `verbose` parameter makes it possible to get information about modules that don't contribute to the final numbers.
51 | - The `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
52 |   for research purposes. For instance, one can drop all convolutions from the counting process by
53 |   specifying `ignore_modules=[torch.nn.Conv2d]`.
54 |
55 | ## Requirements
56 | PyTorch >= 2.0. Use `pip install ptflops==0.7.2.2` to work with torch 1.x.
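For models that require several named inputs, the `input_constructor` argument mentioned in the usage tips works with both backends. A minimal sketch (the two-input `DualInputNet` is a hypothetical stand-in for a real model):

```python
import torch
import torch.nn as nn

from ptflops import get_model_complexity_info


class DualInputNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_x = nn.Linear(32, 16)
        self.fc_y = nn.Linear(8, 16)

    def forward(self, x, y):
        return self.fc_x(x) + self.fc_y(y)


def dual_input_constructor(input_res):
    # input_res is the tuple passed to get_model_complexity_info;
    # the returned dict is unpacked as keyword arguments of forward()
    return {'x': torch.ones((1, *input_res)), 'y': torch.ones((1, 8))}


macs, params = get_model_complexity_info(DualInputNet(), (32,),
                                         input_constructor=dual_input_constructor,
                                         print_per_layer_stat=False)
```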
57 |
58 | ## Install the latest version
59 | From PyPI:
60 | ```bash
61 | pip install ptflops
62 | ```
63 |
64 | From this repository:
65 | ```bash
66 | pip install --upgrade git+https://github.com/sovrasov/flops-counter.pytorch.git
67 | ```
68 |
69 | ## Example
70 | ```python
71 | import torchvision.models as models
72 | import torch
73 | from ptflops import get_model_complexity_info
74 |
75 | with torch.cuda.device(0):
76 |     net = models.densenet161()
77 |     macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='pytorch',
78 |                                              print_per_layer_stat=True, verbose=True)
79 |     print('{:<30} {:<8}'.format('Computational complexity: ', macs))
80 |     print('{:<30} {:<8}'.format('Number of parameters: ', params))
81 |
82 |     macs, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, backend='aten',
83 |                                              print_per_layer_stat=True, verbose=True)
84 |     print('{:<30} {:<8}'.format('Computational complexity: ', macs))
85 |     print('{:<30} {:<8}'.format('Number of parameters: ', params))
86 | ```
87 |
88 | ## Citation
89 | If ptflops was useful for your paper or tech report, please cite me:
90 | ```
91 | @online{ptflops,
92 |   author = {Vladislav Sovrasov},
93 |   title = {ptflops: a flops counting tool for neural networks in pytorch framework},
94 |   year = {2018-2024},
95 |   url = {https://github.com/sovrasov/flops-counter.pytorch},
96 | }
97 | ```
98 |
99 | ## Credits
100 |
101 | Thanks to @warmspringwinds and Horace He for the initial version of the script.
102 |
103 | ## Benchmark
104 |
105 | ### [torchvision](https://pytorch.org/vision/0.16/models.html)
106 |
107 | Model | Input Resolution | Params(M) | MACs(G) (`pytorch`) | MACs(G) (`aten`)
108 | --- |--- |--- |--- |---
109 | alexnet | 224x224 | 61.10 | 0.72 | 0.71
110 | convnext_base | 224x224 | 88.59 | 15.43 | 15.38
111 | densenet121 | 224x224 | 7.98 | 2.90 |
112 | efficientnet_b0 | 224x224 | 5.29 | 0.41 |
113 | efficientnet_v2_m | 224x224 | 54.14 | 5.43 |
114 | googlenet | 224x224 | 13.00 | 1.51 |
115 | inception_v3 | 224x224 | 27.16 | 5.75 | 5.71
116 | maxvit_t | 224x224 | 30.92 | 5.48 |
117 | mnasnet1_0 | 224x224 | 4.38 | 0.33 |
118 | mobilenet_v2 | 224x224 | 3.50 | 0.32 |
119 | mobilenet_v3_large | 224x224 | 5.48 | 0.23 |
120 | regnet_y_1_6gf | 224x224 | 11.20 | 1.65 |
121 | resnet18 | 224x224 | 11.69 | 1.83 | 1.81
122 | resnet50 | 224x224 | 25.56 | 4.13 | 4.09
123 | resnext50_32x4d | 224x224 | 25.03 | 4.29 |
124 | shufflenet_v2_x1_0 | 224x224 | 2.28 | 0.15 |
125 | squeezenet1_0 | 224x224 | 1.25 | 0.84 | 0.82
126 | vgg16 | 224x224 | 138.36 | 15.52 | 15.48
127 | vit_b_16 | 224x224 | 86.57 | 17.61 (wrong) | 16.86
128 | wide_resnet50_2 | 224x224 | 68.88 | 11.45 |
129 |
130 |
131 | ### [timm](https://github.com/huggingface/pytorch-image-models)
132 |
133 | Model | Input Resolution | Params(M) | MACs(G)
134 |
--------------------------------------------------------------------------------
/ptflops/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
3 | * You may use, distribute and modify this code under the
4 | * terms of the MIT license.
5 | * You should have received a copy of the MIT license with
6 | * this file.
If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | 10 | from .flops_counter import FLOPS_BACKEND, get_model_complexity_info 11 | from .utils import flops_to_string, params_to_string 12 | 13 | __all__ = [ 14 | "get_model_complexity_info", 15 | "flops_to_string", 16 | "params_to_string", 17 | "FLOPS_BACKEND", 18 | ] 19 | -------------------------------------------------------------------------------- /ptflops/aten_engine.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2024 Sovrasov V. - All Rights Reserved 3 | * You may use, distribute and modify this code under the 4 | * terms of the MIT license. 5 | * You should have received a copy of the MIT license with 6 | * this file. If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | 10 | import sys 11 | import traceback 12 | from collections import defaultdict 13 | from copy import deepcopy 14 | from functools import partial 15 | from typing import Dict, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch.utils._python_dispatch import TorchDispatchMode 19 | 20 | from ptflops.pytorch_engine import get_model_parameters_number 21 | from ptflops.utils import flops_to_string 22 | from .aten_ops import ATEN_OPS_MAPPING 23 | 24 | 25 | class FlopCounterMode(TorchDispatchMode): 26 | def __init__(self, module=None, verbose=False, print_per_layer_stat=False, 27 | output_params=None, custom_hooks={}, ignored_ops=[]): 28 | self.verbose = verbose 29 | if output_params is None: 30 | output_params = defaultdict(dict) 31 | self.output_params = output_params 32 | self.print_fn = partial(print, **self.output_params['print_params']) 33 | self.all_ops = deepcopy(ATEN_OPS_MAPPING) 34 | self.all_ops.update(custom_hooks) 35 | self.ignored_ops = ignored_ops 36 | 37 | self.print_per_layer_stat = print_per_layer_stat 38 | self.flop_counts = defaultdict(lambda: defaultdict(int)) 39 | self.parents = ['Global'] 40 | self._total_complexity = None 41 | if module is not None: 42 | for name, mod in dict(module.named_children()).items(): 43 | mod.register_forward_pre_hook(self.enter_module(name)) 44 | mod.register_forward_hook(self.exit_module(name)) 45 | 46 | @property 47 | def complexity(self): 48 | return self._total_complexity 49 | 50 | def enter_module(self, name): 51 | def f(*args): 52 | self.parents.append(name) 53 | return f 54 | 55 | def exit_module(self, name): 56 | def f(*args): 57 | assert(self.parents[-1] == name) 58 | self.parents.pop() 59 | return f 60 | 61 | def __enter__(self): 62 | self.flop_counts.clear() 63 | super().__enter__() 64 | 65 | def __exit__(self, *args): 66 | self._total_complexity = sum(self.flop_counts['Global'].values()) 67 | if self.print_per_layer_stat: 68 | self.print_fn('Total:' + 69 | flops_to_string(self._total_complexity, 70 | **self.output_params['serialize_params'])) 71 | for mod in self.flop_counts.keys(): 72 | self.print_fn("Module: ", mod) 73 | for k, v in self.flop_counts[mod].items(): 74 | self.print_fn( 75 | f'{k}: ' + 76 | flops_to_string(v, **self.output_params['serialize_params'])) 77 | self.print_fn() 78 | super().__exit__(*args) 79 | 80 | def __torch_dispatch__(self, func, types, args=(), kwargs=None): 81 | def normalize_tuple(x): 82 | if not isinstance(x, tuple): 83 | return (x,) 84 | return x 85 | kwargs = kwargs if kwargs else {} 86 | 87 | out = func(*args, **kwargs) 88 | func_packet = func._overloadpacket 89 | 90 | if func_packet in self.ignored_ops: 91 | self.print_fn(f'Warning: {func_packet} operation is ignored') 92 | elif 
func_packet in self.all_ops: 93 | flop_count = self.all_ops[func_packet](args, normalize_tuple(out)) 94 | for par in self.parents: 95 | self.flop_counts[par][func_packet] += flop_count 96 | elif self.verbose: 97 | self.print_fn(f'Warning: {func_packet} operation is treated as a zero-op') 98 | 99 | return out 100 | 101 | 102 | def get_flops_aten(model, input_res, 103 | print_per_layer_stat=True, 104 | input_constructor=None, ost=sys.stdout, 105 | verbose=False, ignore_modules=[], 106 | custom_modules_hooks={}, 107 | output_precision=2, 108 | flops_units: Optional[str] = 'GMac', 109 | param_units: Optional[str] = 'M', 110 | extra_config: Dict = {}) -> Tuple[Union[int, None], 111 | Union[int, None]]: 112 | 113 | params_sum = get_model_parameters_number(model) 114 | model.eval() 115 | output_params = {'serialize_params': 116 | {'units': flops_units, 'precision': output_precision}, 117 | 'print_params': {'file': ost}} 118 | 119 | if input_constructor: 120 | batch = input_constructor(input_res) 121 | else: 122 | try: 123 | batch = torch.ones(()).new_empty((1, *input_res), 124 | dtype=next(model.parameters()).dtype, 125 | device=next(model.parameters()).device) 126 | except StopIteration: 127 | batch = torch.ones(()).new_empty((1, *input_res)) 128 | 129 | try: 130 | counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params, 131 | custom_modules_hooks, ignore_modules) 132 | with counter: 133 | if isinstance(batch, dict): 134 | _ = model(**batch) 135 | else: 136 | _ = model(batch) 137 | macs_count = counter.complexity 138 | 139 | except Exception as e: 140 | print("Flops estimation was not finished successfully because of" 141 | f" the following exception:\n{type(e)} : {e}") 142 | traceback.print_exc() 143 | 144 | return None, None 145 | 146 | return macs_count, params_sum 147 | -------------------------------------------------------------------------------- /ptflops/aten_ops.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2023 Sovrasov V. - All Rights Reserved 3 | * You may use, distribute and modify this code under the 4 | * terms of the MIT license. 5 | * You should have received a copy of the MIT license with 6 | * this file. If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | from typing import Any, List 10 | 11 | import torch 12 | 13 | aten = torch.ops.aten 14 | 15 | 16 | def prod(x: torch.Size) -> int: 17 | res = 1 18 | for i in x: 19 | res *= i 20 | return res 21 | 22 | 23 | def matmul_flop(inputs: List[Any], outputs: List[Any]) -> int: 24 | """ 25 | Count flops for matmul. 26 | """ 27 | # Inputs should be a list of length 2. 28 | # Inputs contains the shapes of two matrices. 29 | input_shapes = [v.shape for v in inputs] 30 | assert len(input_shapes) == 2, input_shapes 31 | assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes 32 | flop = prod(input_shapes[0]) * input_shapes[-1][-1] 33 | return flop 34 | 35 | 36 | def addmm_flop(inputs: List[Any], outputs: List[Any]) -> int: 37 | """ 38 | Count flops for fully connected layers (nn.Linear). 39 | Bias is considered if exists. 
40 | """ 41 | # inputs: bias, input, weight 42 | input_shapes = [v.shape for v in inputs[1:3]] 43 | # input_shapes[0]: [batch size, input feature dimension] 44 | # input_shapes[1]: [batch size, output feature dimension] 45 | assert len(input_shapes[0]) == 2, input_shapes[0] 46 | assert len(input_shapes[1]) == 2, input_shapes[1] 47 | batch_size, input_dim = input_shapes[0] 48 | output_dim = input_shapes[1][1] 49 | flops = batch_size * input_dim * output_dim 50 | 51 | if inputs[0] is not None: 52 | flops += batch_size * output_dim 53 | 54 | return flops 55 | 56 | 57 | def bmm_flop(inputs: List[Any], outputs: List[Any]) -> int: 58 | """ 59 | Count flops for the bmm operation. 60 | """ 61 | # Inputs should be a list of length 2. 62 | # Inputs contains the shapes of two tensors. 63 | assert len(inputs) == 2, len(inputs) 64 | input_shapes = [v.shape for v in inputs] 65 | n, c, t = input_shapes[0] 66 | d = input_shapes[-1][-1] 67 | flop = n * c * t * d 68 | return flop 69 | 70 | 71 | def conv_flop_count( 72 | x_shape: torch.Size, 73 | w_shape: torch.Size, 74 | out_shape: torch.Size, 75 | transposed: bool = False, 76 | bias: bool = False, 77 | ) -> int: 78 | """ 79 | Count MACs for convolution. 80 | Summation is ignored when applying conv kernel, but counted for bias. 81 | Args: 82 | x_shape (list(int)): The input shape before convolution. 83 | w_shape (list(int)): The filter shape. 84 | out_shape (list(int)): The output shape after convolution. 85 | transposed (bool): is the convolution transposed 86 | bias (bool): is the bias counted 87 | Returns: 88 | int: the number of MACs 89 | """ 90 | batch_size = x_shape[0] 91 | conv_shape = (x_shape if transposed else out_shape)[2:] 92 | flop = batch_size * prod(w_shape) * prod(conv_shape) 93 | if bias: 94 | flop += batch_size * out_shape[1] * prod(out_shape[2:]) 95 | return flop 96 | 97 | 98 | def conv_flop(inputs: List[Any], outputs: List[Any]) -> int: 99 | """ 100 | Count flops for convolution. 101 | """ 102 | (input, w, b, stride, pad, dilation, 103 | transposed, _, groups) = inputs 104 | output = outputs[0] 105 | return conv_flop_count(input.shape, w.shape, output.shape, 106 | transposed=transposed, bias=b is not None) 107 | 108 | 109 | ATEN_OPS_MAPPING = { 110 | aten.mm: matmul_flop, 111 | aten.matmul: matmul_flop, 112 | aten.addmm: addmm_flop, 113 | aten.bmm: bmm_flop, 114 | aten.convolution: conv_flop, 115 | aten._convolution: conv_flop, 116 | } 117 | -------------------------------------------------------------------------------- /ptflops/flops_counter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved 3 | * You may use, distribute and modify this code under the 4 | * terms of the MIT license. 5 | * You should have received a copy of the MIT license with 6 | * this file. 
/ptflops/flops_counter.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2019-2024 Sovrasov V. - All Rights Reserved
3 | * You may use, distribute and modify this code under the
4 | * terms of the MIT license.
5 | * You should have received a copy of the MIT license with
6 | * this file. If not visit https://opensource.org/licenses/MIT
7 | '''
8 |
9 | import sys
10 | from enum import Enum
11 | from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, Union
12 |
13 | import torch.nn as nn
14 |
15 | from .aten_engine import get_flops_aten
16 | from .pytorch_engine import get_flops_pytorch
17 | from .utils import flops_to_string, params_to_string
18 |
19 |
20 | class FLOPS_BACKEND(Enum):
21 |     PYTORCH = 'pytorch'
22 |     ATEN = 'aten'
23 |
24 |
25 | def get_model_complexity_info(model: nn.Module,
26 |                               input_res: Tuple[int, ...],
27 |                               print_per_layer_stat: bool = True,
28 |                               as_strings: bool = True,
29 |                               input_constructor: Optional[Callable[[Tuple], Dict]] = None,
30 |                               ost: TextIO = sys.stdout,
31 |                               verbose: bool = False,
32 |                               ignore_modules: List[Union[nn.Module, Any]] = [],
33 |                               custom_modules_hooks: Dict[Union[nn.Module, Any], Any] = {},
34 |                               backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.PYTORCH,
35 |                               flops_units: Optional[str] = None,
36 |                               param_units: Optional[str] = None,
37 |                               output_precision: int = 2,
38 |                               backend_specific_config: Dict = {}) -> Tuple[
39 |                                   Union[str, int, None],
40 |                                   Union[str, int, None]]:
41 |     """
42 |     Analyzes the input model and collects the amounts of parameters and MACs
43 |     required to make a forward pass of the model.
44 |
45 |     :param model: Input model to analyze
46 |     :type model: nn.Module
47 |     :param input_res: A tuple that sets the input resolution for the model. Batch
48 |         dimension is added automatically: (3, 224, 224) -> (1, 3, 224, 224).
49 |     :type input_res: Tuple[int, ...]
50 |     :param print_per_layer_stat: Flag to enable or disable printing of per-layer
51 |         MACs/params statistics. This feature works only for layers derived
52 |         from torch.nn.Module. Other operations are ignored.
53 |     :type print_per_layer_stat: bool
54 |     :param as_strings: Flag that allows getting a ready-to-print string representation
55 |         of the final params/MACs estimations. Otherwise, a tuple with raw numbers
56 |         will be returned.
57 |     :type as_strings: bool
58 |     :param input_constructor: A callable that takes the :input_res: parameter and
59 |         returns an output suitable for the model. It can be used if the model requires
60 |         more than one input tensor or any other kind of irregular input.
61 |     :type input_constructor: Optional[Callable[[Tuple], Dict]]
62 |     :param ost: A stream to print output to.
63 |     :type ost: TextIO
64 |     :param verbose: Parameter to control printing of extra information and warnings.
65 |     :type verbose: bool
66 |     :param ignore_modules: A list of torch.nn.Module or torch.ops.aten modules to ignore.
67 |     :type ignore_modules: List[Union[nn.Module, Any]]
68 |     :param custom_modules_hooks: A dict that contains custom hooks for torch.nn.Module or
69 |         torch.ops.aten modules.
70 |     :type custom_modules_hooks: Dict[Union[nn.Module, Any], Any]
71 |     :param backend: Backend that is used for evaluating model complexity.
72 |     :type backend: FLOPS_BACKEND
73 |     :param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
74 |     :type flops_units: Optional[str]
75 |     :param param_units: Units for string representation of params (M, K or B).
76 |     :type param_units: Optional[str]
77 |     :param output_precision: Floating point precision for representing MACs/params in
78 |         given units.
79 |     :type output_precision: int
80 |     :param backend_specific_config: Extra configuration for a specific backend.
81 |     :type backend_specific_config: dict
82 |
83 |     Returns:
84 |         Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple
85 |         (macs, params): Nones in case of a failure during computations; strings
86 |         if :as_strings: is true, plain numbers otherwise.
87 |     """
88 |     assert type(input_res) is tuple
89 |     assert len(input_res) >= 1
90 |     assert isinstance(model, nn.Module)
91 |
92 |     if FLOPS_BACKEND(backend) == FLOPS_BACKEND.PYTORCH:
93 |         flops_count, params_count = \
94 |             get_flops_pytorch(model, input_res,
95 |                               print_per_layer_stat,
96 |                               input_constructor, ost,
97 |                               verbose, ignore_modules,
98 |                               custom_modules_hooks,
99 |                               output_precision=output_precision,
100 |                               flops_units=flops_units,
101 |                               param_units=param_units,
102 |                               extra_config=backend_specific_config)
103 |     elif FLOPS_BACKEND(backend) == FLOPS_BACKEND.ATEN:
104 |         flops_count, params_count = get_flops_aten(model, input_res,
105 |                                                    print_per_layer_stat,
106 |                                                    input_constructor, ost,
107 |                                                    verbose, ignore_modules,
108 |                                                    custom_modules_hooks,
109 |                                                    output_precision=output_precision,
110 |                                                    flops_units=flops_units,
111 |                                                    param_units=param_units,
112 |                                                    extra_config=backend_specific_config)
113 |     else:
114 |         raise ValueError('Wrong backend name')
115 |
116 |     if as_strings and flops_count is not None and params_count is not None:
117 |         flops_string = flops_to_string(
118 |             flops_count,
119 |             units=flops_units,
120 |             precision=output_precision
121 |         )
122 |         params_string = params_to_string(
123 |             params_count,
124 |             units=param_units,
125 |             precision=output_precision
126 |         )
127 |         return flops_string, params_string
128 |
129 |     return flops_count, params_count
130 |
--------------------------------------------------------------------------------
/ptflops/pytorch_engine.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2021-2023 Sovrasov V. - All Rights Reserved
3 | * You may use, distribute and modify this code under the
4 | * terms of the MIT license.
5 | * You should have received a copy of the MIT license with
6 | * this file.
If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | import sys 10 | import traceback 11 | from functools import partial 12 | from typing import Dict, Optional, Tuple, Union 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from .pytorch_ops import (CUSTOM_MODULES_MAPPING, FUNCTIONAL_MAPPING, 19 | MODULES_MAPPING, TENSOR_OPS_MAPPING) 20 | from .utils import flops_to_string, params_to_string 21 | 22 | 23 | def get_flops_pytorch(model, input_res, 24 | print_per_layer_stat=True, 25 | input_constructor=None, ost=sys.stdout, 26 | verbose=False, ignore_modules=[], 27 | custom_modules_hooks={}, 28 | output_precision=2, 29 | flops_units: Optional[str] = 'GMac', 30 | param_units: Optional[str] = 'M', 31 | extra_config: Dict = {}) -> Tuple[Union[int, None], 32 | Union[int, None]]: 33 | global CUSTOM_MODULES_MAPPING 34 | CUSTOM_MODULES_MAPPING = custom_modules_hooks 35 | flops_model = add_flops_counting_methods(model) 36 | flops_model.eval() 37 | flops_model.start_flops_count(ost=ost, verbose=verbose, 38 | ignore_list=ignore_modules) 39 | if input_constructor: 40 | batch = input_constructor(input_res) 41 | else: 42 | try: 43 | batch = torch.ones(()).new_empty((1, *input_res), 44 | dtype=next(flops_model.parameters()).dtype, 45 | device=next(flops_model.parameters()).device) 46 | except StopIteration: 47 | batch = torch.ones(()).new_empty((1, *input_res)) 48 | 49 | enable_func_ops_patching = extra_config.get('count_functional', True) 50 | torch_functional_flops = [] 51 | torch_tensor_ops_flops = [] 52 | if enable_func_ops_patching: 53 | patch_functional(torch_functional_flops) 54 | patch_tensor_ops(torch_tensor_ops_flops) 55 | 56 | def reset_environment(): 57 | flops_model.stop_flops_count() 58 | if enable_func_ops_patching: 59 | unpatch_functional() 60 | unpatch_tensor_ops() 61 | global CUSTOM_MODULES_MAPPING 62 | CUSTOM_MODULES_MAPPING = {} 63 | 64 | try: 65 | if isinstance(batch, dict): 66 | _ = flops_model(**batch) 67 | else: 68 | _ = flops_model(batch) 69 | flops_count, params_count = flops_model.compute_average_flops_cost() 70 | flops_count += sum(torch_functional_flops) 71 | flops_count += sum(torch_tensor_ops_flops) 72 | 73 | except Exception as e: 74 | print("Flops estimation was not finished successfully because of" 75 | f" the following exception:\n{type(e)} : {e}") 76 | traceback.print_exc() 77 | reset_environment() 78 | 79 | return None, None 80 | 81 | if print_per_layer_stat: 82 | print_model_with_flops( 83 | flops_model, 84 | flops_count, 85 | params_count, 86 | ost=ost, 87 | flops_units=flops_units, 88 | param_units=param_units, 89 | precision=output_precision 90 | ) 91 | reset_environment() 92 | 93 | return int(flops_count), params_count 94 | 95 | 96 | def accumulate_flops(self): 97 | if is_supported_instance(self): 98 | return self.__flops__ 99 | else: 100 | sum = 0 101 | for m in self.children(): 102 | sum += m.accumulate_flops() 103 | return sum 104 | 105 | 106 | def print_model_with_flops(model, total_flops, total_params, 107 | flops_units: Optional[str] = 'GMac', 108 | param_units: Optional[str] = 'M', 109 | precision=3, ost=sys.stdout): 110 | if total_flops < 1: 111 | total_flops = 1 112 | if total_params < 1: 113 | total_params = 1 114 | 115 | def accumulate_params(self): 116 | if is_supported_instance(self): 117 | return self.__params__ 118 | else: 119 | sum = 0 120 | for m in self.children(): 121 | sum += m.accumulate_params() 122 | return sum 123 | 124 | def flops_repr(self): 125 | accumulated_params_num = 
self.accumulate_params() 126 | accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__ 127 | if accumulated_params_num > total_params: 128 | print('Warning: parameters of some of the modules were counted twice because' 129 | ' of multiple links to the same modules.' 130 | ' Extended per layer parameters num statistic could be unreliable.') 131 | 132 | return ', '.join([params_to_string(accumulated_params_num, 133 | units=param_units, precision=precision), 134 | '{:.3%} Params'.format(accumulated_params_num / total_params), 135 | flops_to_string(accumulated_flops_cost, 136 | units=flops_units, precision=precision), 137 | '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), 138 | self.original_extra_repr()]) 139 | 140 | def add_extra_repr(m): 141 | m.accumulate_flops = accumulate_flops.__get__(m) 142 | m.accumulate_params = accumulate_params.__get__(m) 143 | flops_extra_repr = flops_repr.__get__(m) 144 | if m.extra_repr != flops_extra_repr: 145 | m.original_extra_repr = m.extra_repr 146 | m.extra_repr = flops_extra_repr 147 | assert m.extra_repr != m.original_extra_repr 148 | 149 | def del_extra_repr(m): 150 | if hasattr(m, 'original_extra_repr'): 151 | m.extra_repr = m.original_extra_repr 152 | del m.original_extra_repr 153 | if hasattr(m, 'accumulate_flops'): 154 | del m.accumulate_flops 155 | 156 | model.apply(add_extra_repr) 157 | print(repr(model), file=ost) 158 | model.apply(del_extra_repr) 159 | 160 | 161 | def get_model_parameters_number(model): 162 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 163 | return params_num 164 | 165 | 166 | def add_flops_counting_methods(net_main_module): 167 | # adding additional methods to the existing module object, 168 | # this is done this way so that each function has access to self object 169 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 170 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 171 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 172 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( 173 | net_main_module) 174 | 175 | net_main_module.reset_flops_count() 176 | 177 | return net_main_module 178 | 179 | 180 | def compute_average_flops_cost(self): 181 | """ 182 | A method that will be available after add_flops_counting_methods() is called 183 | on a desired net object. 184 | 185 | Returns current mean flops consumption per image. 186 | 187 | """ 188 | 189 | for m in self.modules(): 190 | m.accumulate_flops = accumulate_flops.__get__(m) 191 | 192 | flops_sum = self.accumulate_flops() 193 | 194 | for m in self.modules(): 195 | if hasattr(m, 'accumulate_flops'): 196 | del m.accumulate_flops 197 | 198 | params_sum = get_model_parameters_number(self) 199 | return flops_sum / self.__batch_counter__, params_sum 200 | 201 | 202 | def start_flops_count(self, **kwargs): 203 | """ 204 | A method that will be available after add_flops_counting_methods() is called 205 | on a desired net object. 206 | 207 | Activates the computation of mean flops consumption per image. 208 | Call it before you run the network. 
209 | 210 | """ 211 | add_batch_counter_hook_function(self) 212 | 213 | seen_types = set() 214 | 215 | def add_flops_counter_hook_function(module, ost, verbose, ignore_list): 216 | if type(module) in ignore_list: 217 | seen_types.add(type(module)) 218 | if is_supported_instance(module): 219 | module.__params__ = 0 220 | elif is_supported_instance(module): 221 | if hasattr(module, '__flops_handle__'): 222 | return 223 | if type(module) in CUSTOM_MODULES_MAPPING: 224 | handle = module.register_forward_hook( 225 | CUSTOM_MODULES_MAPPING[type(module)]) 226 | else: 227 | handle = module.register_forward_hook(MODULES_MAPPING[type(module)]) 228 | module.__flops_handle__ = handle 229 | seen_types.add(type(module)) 230 | else: 231 | if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and \ 232 | not type(module) in seen_types: 233 | print('Warning: module ' + type(module).__name__ + 234 | ' is treated as a zero-op.', file=ost) 235 | seen_types.add(type(module)) 236 | 237 | self.apply(partial(add_flops_counter_hook_function, **kwargs)) 238 | 239 | 240 | def stop_flops_count(self): 241 | """ 242 | A method that will be available after add_flops_counting_methods() is called 243 | on a desired net object. 244 | 245 | Stops computing the mean flops consumption per image. 246 | Call whenever you want to pause the computation. 247 | 248 | """ 249 | remove_batch_counter_hook_function(self) 250 | self.apply(remove_flops_counter_hook_function) 251 | self.apply(remove_flops_counter_variables) 252 | 253 | 254 | def reset_flops_count(self): 255 | """ 256 | A method that will be available after add_flops_counting_methods() is called 257 | on a desired net object. 258 | 259 | Resets statistics computed so far. 260 | 261 | """ 262 | add_batch_counter_variables_or_reset(self) 263 | self.apply(add_flops_counter_variable_or_reset) 264 | 265 | 266 | # ---- Internal functions 267 | def batch_counter_hook(module, input, output): 268 | batch_size = 1 269 | if len(input) > 0: 270 | # Can have multiple inputs, getting the first one 271 | input = input[0] 272 | batch_size = len(input) 273 | else: 274 | pass 275 | print('Warning! 
No positional inputs found for a module,' 276 | ' assuming batch size is 1.') 277 | module.__batch_counter__ += batch_size 278 | 279 | 280 | def add_batch_counter_variables_or_reset(module): 281 | 282 | module.__batch_counter__ = 0 283 | 284 | 285 | def add_batch_counter_hook_function(module): 286 | if hasattr(module, '__batch_counter_handle__'): 287 | return 288 | 289 | handle = module.register_forward_hook(batch_counter_hook) 290 | module.__batch_counter_handle__ = handle 291 | 292 | 293 | def remove_batch_counter_hook_function(module): 294 | if hasattr(module, '__batch_counter_handle__'): 295 | module.__batch_counter_handle__.remove() 296 | del module.__batch_counter_handle__ 297 | 298 | 299 | def add_flops_counter_variable_or_reset(module): 300 | if is_supported_instance(module): 301 | if hasattr(module, '__flops__') or hasattr(module, '__params__'): 302 | print('Warning: variables __flops__ or __params__ are already ' 303 | 'defined for the module' + type(module).__name__ + 304 | ' ptflops can affect your code!') 305 | module.__ptflops_backup_flops__ = module.__flops__ 306 | module.__ptflops_backup_params__ = module.__params__ 307 | module.__flops__ = 0 308 | module.__params__ = get_model_parameters_number(module) 309 | 310 | 311 | def is_supported_instance(module): 312 | if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING: 313 | return True 314 | return False 315 | 316 | 317 | def remove_flops_counter_hook_function(module): 318 | if is_supported_instance(module): 319 | if hasattr(module, '__flops_handle__'): 320 | module.__flops_handle__.remove() 321 | del module.__flops_handle__ 322 | 323 | 324 | def remove_flops_counter_variables(module): 325 | if is_supported_instance(module): 326 | if hasattr(module, '__flops__'): 327 | del module.__flops__ 328 | if hasattr(module, '__ptflops_backup_flops__'): 329 | module.__flops__ = module.__ptflops_backup_flops__ 330 | if hasattr(module, '__params__'): 331 | del module.__params__ 332 | if hasattr(module, '__ptflops_backup_params__'): 333 | module.__params__ = module.__ptflops_backup_params__ 334 | 335 | 336 | class torch_function_wrapper: 337 | def __init__(self, op, handler, collector) -> None: 338 | self.collector = collector 339 | self.op = op 340 | self.handler = handler 341 | 342 | def __call__(self, *args, **kwds): 343 | flops = self.handler(*args, **kwds) 344 | self.collector.append(flops) 345 | return self.op(*args, **kwds) 346 | 347 | 348 | def patch_functional(collector): 349 | # F.linear = torch_function_wrapper(F.linear, FUNCTIONAL_MAPPING[F.linear], collector) 350 | F.relu = torch_function_wrapper(F.relu, FUNCTIONAL_MAPPING[F.relu], collector) 351 | F.prelu = torch_function_wrapper(F.prelu, FUNCTIONAL_MAPPING[F.prelu], collector) 352 | F.elu = torch_function_wrapper(F.elu, FUNCTIONAL_MAPPING[F.elu], collector) 353 | F.relu6 = torch_function_wrapper(F.relu6, FUNCTIONAL_MAPPING[F.relu6], collector) 354 | F.gelu = torch_function_wrapper(F.gelu, FUNCTIONAL_MAPPING[F.gelu], collector) 355 | 356 | F.avg_pool1d = torch_function_wrapper(F.avg_pool1d, 357 | FUNCTIONAL_MAPPING[F.avg_pool1d], collector) 358 | F.avg_pool2d = torch_function_wrapper(F.avg_pool2d, 359 | FUNCTIONAL_MAPPING[F.avg_pool2d], collector) 360 | F.avg_pool3d = torch_function_wrapper(F.avg_pool3d, 361 | FUNCTIONAL_MAPPING[F.avg_pool3d], collector) 362 | F.max_pool1d = torch_function_wrapper(F.max_pool1d, 363 | FUNCTIONAL_MAPPING[F.max_pool1d], collector) 364 | F.max_pool2d = torch_function_wrapper(F.max_pool2d, 365 | 
FUNCTIONAL_MAPPING[F.max_pool2d], collector) 366 | F.max_pool3d = torch_function_wrapper(F.max_pool3d, 367 | FUNCTIONAL_MAPPING[F.max_pool3d], collector) 368 | F.adaptive_avg_pool1d = torch_function_wrapper( 369 | F.adaptive_avg_pool1d, FUNCTIONAL_MAPPING[F.adaptive_avg_pool1d], collector) 370 | F.adaptive_avg_pool2d = torch_function_wrapper( 371 | F.adaptive_avg_pool2d, FUNCTIONAL_MAPPING[F.adaptive_avg_pool2d], collector) 372 | F.adaptive_avg_pool3d = torch_function_wrapper( 373 | F.adaptive_avg_pool3d, FUNCTIONAL_MAPPING[F.adaptive_avg_pool3d], collector) 374 | F.adaptive_max_pool1d = torch_function_wrapper( 375 | F.adaptive_max_pool1d, FUNCTIONAL_MAPPING[F.adaptive_max_pool1d], collector) 376 | F.adaptive_max_pool2d = torch_function_wrapper( 377 | F.adaptive_max_pool2d, FUNCTIONAL_MAPPING[F.adaptive_max_pool2d], collector) 378 | F.adaptive_max_pool3d = torch_function_wrapper( 379 | F.adaptive_max_pool3d, FUNCTIONAL_MAPPING[F.adaptive_max_pool3d], collector) 380 | 381 | F.softmax = torch_function_wrapper( 382 | F.softmax, FUNCTIONAL_MAPPING[F.softmax], collector) 383 | 384 | F.upsample = torch_function_wrapper( 385 | F.upsample, FUNCTIONAL_MAPPING[F.upsample], collector) 386 | F.interpolate = torch_function_wrapper( 387 | F.interpolate, FUNCTIONAL_MAPPING[F.interpolate], collector) 388 | 389 | if hasattr(F, "silu"): 390 | F.silu = torch_function_wrapper(F.silu, FUNCTIONAL_MAPPING[F.silu], collector) 391 | 392 | 393 | def unpatch_functional(): 394 | # F.linear = F.linear.op 395 | F.relu = F.relu.op 396 | F.prelu = F.prelu.op 397 | F.elu = F.elu.op 398 | F.relu6 = F.relu6.op 399 | F.gelu = F.gelu.op 400 | if hasattr(F, "silu"): 401 | F.silu = F.silu.op 402 | 403 | F.avg_pool1d = F.avg_pool1d.op 404 | F.avg_pool2d = F.avg_pool2d.op 405 | F.avg_pool3d = F.avg_pool3d.op 406 | F.max_pool1d = F.max_pool1d.op 407 | F.max_pool2d = F.max_pool2d.op 408 | F.max_pool3d = F.max_pool3d.op 409 | F.adaptive_avg_pool1d = F.adaptive_avg_pool1d.op 410 | F.adaptive_avg_pool2d = F.adaptive_avg_pool2d.op 411 | F.adaptive_avg_pool3d = F.adaptive_avg_pool3d.op 412 | F.adaptive_max_pool1d = F.adaptive_max_pool1d.op 413 | F.adaptive_max_pool2d = F.adaptive_max_pool2d.op 414 | F.adaptive_max_pool3d = F.adaptive_max_pool3d.op 415 | 416 | F.softmax = F.softmax.op 417 | 418 | F.upsample = F.upsample.op 419 | F.interpolate = F.interpolate.op 420 | 421 | 422 | def wrap_tensor_op(op, collector): 423 | tensor_op_handler = torch_function_wrapper( 424 | op, TENSOR_OPS_MAPPING[op], collector) 425 | 426 | def wrapper(*args, **kwargs): 427 | return tensor_op_handler(*args, **kwargs) 428 | 429 | wrapper.op = tensor_op_handler.op 430 | 431 | return wrapper 432 | 433 | 434 | def patch_tensor_ops(collector): 435 | torch.matmul = torch_function_wrapper( 436 | torch.matmul, TENSOR_OPS_MAPPING[torch.matmul], collector) 437 | torch.Tensor.matmul = wrap_tensor_op(torch.Tensor.matmul, collector) 438 | torch.mm = torch_function_wrapper( 439 | torch.mm, TENSOR_OPS_MAPPING[torch.mm], collector) 440 | torch.Tensor.mm = wrap_tensor_op(torch.Tensor.mm, collector) 441 | torch.bmm = torch_function_wrapper( 442 | torch.bmm, TENSOR_OPS_MAPPING[torch.bmm], collector) 443 | torch.Tensor.bmm = wrap_tensor_op(torch.Tensor.bmm, collector) 444 | 445 | torch.addmm = torch_function_wrapper( 446 | torch.addmm, TENSOR_OPS_MAPPING[torch.addmm], collector) 447 | torch.Tensor.addmm = wrap_tensor_op(torch.Tensor.addmm, collector) 448 | torch.baddbmm = torch_function_wrapper( 449 | torch.baddbmm, TENSOR_OPS_MAPPING[torch.baddbmm], collector) 450 | 451 | 
torch.mul = torch_function_wrapper( 452 | torch.mul, TENSOR_OPS_MAPPING[torch.mul], collector) 453 | torch.Tensor.mul = wrap_tensor_op(torch.Tensor.mul, collector) 454 | torch.add = torch_function_wrapper( 455 | torch.add, TENSOR_OPS_MAPPING[torch.add], collector) 456 | torch.Tensor.add = wrap_tensor_op(torch.Tensor.add, collector) 457 | 458 | 459 | def unpatch_tensor_ops(): 460 | torch.matmul = torch.matmul.op 461 | torch.Tensor.matmul = torch.Tensor.matmul.op 462 | torch.mm = torch.mm.op 463 | torch.Tensor.mm = torch.Tensor.mm.op 464 | torch.bmm = torch.bmm.op 465 | torch.Tensor.bmm = torch.Tensor.bmm.op 466 | 467 | torch.addmm = torch.addmm.op 468 | torch.Tensor.addmm = torch.Tensor.addmm.op 469 | torch.baddbmm = torch.baddbmm.op 470 | 471 | torch.mul = torch.mul.op 472 | torch.Tensor.mul = torch.Tensor.mul.op 473 | torch.add = torch.add.op 474 | torch.Tensor.add = torch.Tensor.add.op 475 | -------------------------------------------------------------------------------- /ptflops/pytorch_ops.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2021-2023 Sovrasov V. - All Rights Reserved 3 | * You may use, distribute and modify this code under the 4 | * terms of the MIT license. 5 | * You should have received a copy of the MIT license with 6 | * this file. If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def empty_flops_counter_hook(module, input, output): 16 | module.__flops__ += 0 17 | 18 | 19 | def upsample_flops_counter_hook(module, input, output): 20 | output_size = output[0] 21 | batch_size = output_size.shape[0] 22 | output_elements_count = batch_size 23 | for val in output_size.shape[1:]: 24 | output_elements_count *= val 25 | module.__flops__ += int(output_elements_count) 26 | 27 | 28 | def relu_flops_counter_hook(module, input, output): 29 | active_elements_count = output.numel() 30 | module.__flops__ += int(active_elements_count) 31 | 32 | 33 | def linear_flops_counter_hook(module, input, output): 34 | input = input[0] 35 | # pytorch checks dimensions, so here we don't care much 36 | output_last_dim = output.shape[-1] 37 | input_last_dim = input.shape[-1] 38 | pre_last_dims_prod = np.prod(input.shape[0:-1], dtype=np.int64) 39 | bias_flops = output_last_dim if module.bias is not None else 0 40 | module.__flops__ += int((input_last_dim * output_last_dim + bias_flops) 41 | * pre_last_dims_prod) 42 | 43 | 44 | def pool_flops_counter_hook(module, input, output): 45 | input = input[0] 46 | module.__flops__ += int(np.prod(input.shape, dtype=np.int64)) 47 | 48 | 49 | def bn_flops_counter_hook(module, input, output): 50 | input = input[0] 51 | 52 | batch_flops = np.prod(input.shape, dtype=np.int64) 53 | if hasattr(module, "affine") and module.affine: 54 | batch_flops *= 2 55 | module.__flops__ += int(batch_flops) 56 | 57 | 58 | def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0): 59 | # Can have multiple inputs, getting the first one 60 | input = input[0] 61 | 62 | batch_size = input.shape[0] 63 | output_dims = list(output.shape[2:]) 64 | 65 | kernel_dims = list(conv_module.kernel_size) 66 | in_channels = conv_module.in_channels 67 | out_channels = conv_module.out_channels 68 | groups = conv_module.groups 69 | 70 | filters_per_channel = out_channels // groups 71 | conv_per_position_flops = int(np.prod(kernel_dims, dtype=np.int64)) * \ 72 | (in_channels * 
filters_per_channel + extra_per_position_flops)
73 |
74 |     active_elements_count = batch_size * int(np.prod(output_dims, dtype=np.int64))
75 |
76 |     overall_conv_flops = conv_per_position_flops * active_elements_count
77 |
78 |     bias_flops = 0
79 |
80 |     if conv_module.bias is not None:
81 |
82 |         bias_flops = out_channels * active_elements_count
83 |
84 |     overall_flops = overall_conv_flops + bias_flops
85 |
86 |     conv_module.__flops__ += int(overall_flops)
87 |
88 |
89 | def deformable_conv_flops_counter_hook(conv_module, input, output):
90 |     # 20 = 4 x 5 is an approximate cost of bilinear interpolation; a 2x2 grid is used
91 |     # 4 is an approximate cost of the fractional coordinates computation
92 |     deformable_conv_extra_complexity = 20 + 4
93 |     # consider also the modulation multiplication
94 |     if len(input) == 3 and input[2] is not None:
95 |         deformable_conv_extra_complexity += 1
96 |     conv_flops_counter_hook(conv_module, input, output, deformable_conv_extra_complexity)
97 |
98 |
99 | def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
100 |     # matrix-matrix mult of the input state and the internal state
101 |     flops += w_ih.shape[0]*w_ih.shape[1]
102 |     # matrix-matrix mult of the hidden state and the internal state
103 |     flops += w_hh.shape[0]*w_hh.shape[1]
104 |     if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
105 |         # add both operations
106 |         flops += rnn_module.hidden_size
107 |     elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
108 |         # hadamard product of r
109 |         flops += rnn_module.hidden_size
110 |         # adding operations from both states
111 |         flops += rnn_module.hidden_size*3
112 |         # last two hadamard products and an add
113 |         flops += rnn_module.hidden_size*3
114 |     elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
115 |         # adding operations from both states
116 |         flops += rnn_module.hidden_size*4
117 |         # two hadamard products and an add for the C state
118 |         flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
119 |         # final hadamard
120 |         flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
121 |     return flops
122 |
123 |
124 | def rnn_flops_counter_hook(rnn_module, input, output):
125 |     """
126 |     Takes into account that the batch goes first, contrary
127 |     to the common PyTorch convention (though it actually doesn't matter here).
128 | If sigmoid and tanh are hard, only a comparison FLOPS should be accurate 129 | """ 130 | flops = 0 131 | # input is a tuple containing a sequence to process and (optionally) hidden state 132 | inp = input[0] 133 | batch_size = inp.shape[0] 134 | seq_length = inp.shape[1] 135 | num_layers = rnn_module.num_layers 136 | 137 | for i in range(num_layers): 138 | w_ih = rnn_module.__getattr__('weight_ih_l' + str(i)) 139 | w_hh = rnn_module.__getattr__('weight_hh_l' + str(i)) 140 | if i == 0: 141 | input_size = rnn_module.input_size 142 | else: 143 | input_size = rnn_module.hidden_size 144 | flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size) 145 | if rnn_module.bias: 146 | b_ih = rnn_module.__getattr__('bias_ih_l' + str(i)) 147 | b_hh = rnn_module.__getattr__('bias_hh_l' + str(i)) 148 | flops += b_ih.shape[0] + b_hh.shape[0] 149 | 150 | flops *= batch_size 151 | flops *= seq_length 152 | if rnn_module.bidirectional: 153 | flops *= 2 154 | rnn_module.__flops__ += int(flops) 155 | 156 | 157 | def rnn_cell_flops_counter_hook(rnn_cell_module, input, output): 158 | flops = 0 159 | inp = input[0] 160 | batch_size = inp.shape[0] 161 | w_ih = rnn_cell_module.__getattr__('weight_ih') 162 | w_hh = rnn_cell_module.__getattr__('weight_hh') 163 | input_size = inp.shape[1] 164 | flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size) 165 | if rnn_cell_module.bias: 166 | b_ih = rnn_cell_module.__getattr__('bias_ih') 167 | b_hh = rnn_cell_module.__getattr__('bias_hh') 168 | flops += b_ih.shape[0] + b_hh.shape[0] 169 | 170 | flops *= batch_size 171 | rnn_cell_module.__flops__ += int(flops) 172 | 173 | 174 | def multihead_attention_counter_hook(multihead_attention_module, input, output): 175 | flops = 0 176 | 177 | q, k, v = input 178 | 179 | batch_first = multihead_attention_module.batch_first \ 180 | if hasattr(multihead_attention_module, 'batch_first') else False 181 | if batch_first: 182 | batch_size = q.shape[0] 183 | len_idx = 1 184 | else: 185 | batch_size = q.shape[1] 186 | len_idx = 0 187 | 188 | dim_idx = 2 189 | 190 | qdim = q.shape[dim_idx] 191 | kdim = k.shape[dim_idx] 192 | vdim = v.shape[dim_idx] 193 | 194 | qlen = q.shape[len_idx] 195 | klen = k.shape[len_idx] 196 | vlen = v.shape[len_idx] 197 | 198 | num_heads = multihead_attention_module.num_heads 199 | assert qdim == multihead_attention_module.embed_dim 200 | 201 | if multihead_attention_module.kdim is None: 202 | assert kdim == qdim 203 | if multihead_attention_module.vdim is None: 204 | assert vdim == qdim 205 | 206 | flops = 0 207 | 208 | # Q scaling 209 | flops += qlen * qdim 210 | 211 | # Initial projections 212 | flops += ( 213 | (qlen * qdim * qdim) # QW 214 | + (klen * kdim * kdim) # KW 215 | + (vlen * vdim * vdim) # VW 216 | ) 217 | 218 | if multihead_attention_module.in_proj_bias is not None: 219 | flops += (qlen + klen + vlen) * qdim 220 | 221 | # attention heads: scale, matmul, softmax, matmul 222 | qk_head_dim = qdim // num_heads 223 | v_head_dim = vdim // num_heads 224 | 225 | head_flops = ( 226 | (qlen * klen * qk_head_dim) # QK^T 227 | + (qlen * klen) # softmax 228 | + (qlen * klen * v_head_dim) # AV 229 | ) 230 | 231 | flops += num_heads * head_flops 232 | 233 | # final projection, bias is always enabled 234 | flops += qlen * vdim * (vdim + 1) 235 | 236 | flops *= batch_size 237 | multihead_attention_module.__flops__ += int(flops) 238 | 239 | 240 | def timm_attention_counter_hook(attention_module, input, output): 241 | flops = 0 242 | B, N, C = input[0].shape # [Batch_size, Seq_len, Dimension] 243 | 244 
| # QKV projection is already covered in MODULES_MAPPING 245 | 246 | # Q scaling 247 | flops += N * attention_module.head_dim * attention_module.num_heads 248 | 249 | # head flops 250 | head_flops = ( 251 | (N * N * attention_module.head_dim) # QK^T 252 | + (N * N) # softmax 253 | + (N * N * attention_module.head_dim) # AV 254 | ) 255 | flops += head_flops * attention_module.num_heads 256 | 257 | # Final projection is already covered in MODULES_MAPPING 258 | 259 | flops *= B 260 | attention_module.__flops__ += int(flops) 261 | 262 | 263 | CUSTOM_MODULES_MAPPING = {} 264 | 265 | MODULES_MAPPING = { 266 | # convolutions 267 | nn.Conv1d: conv_flops_counter_hook, 268 | nn.Conv2d: conv_flops_counter_hook, 269 | nn.Conv3d: conv_flops_counter_hook, 270 | # activations 271 | nn.ReLU: relu_flops_counter_hook, 272 | nn.PReLU: relu_flops_counter_hook, 273 | nn.ELU: relu_flops_counter_hook, 274 | nn.LeakyReLU: relu_flops_counter_hook, 275 | nn.ReLU6: relu_flops_counter_hook, 276 | # poolings 277 | nn.MaxPool1d: pool_flops_counter_hook, 278 | nn.AvgPool1d: pool_flops_counter_hook, 279 | nn.AvgPool2d: pool_flops_counter_hook, 280 | nn.MaxPool2d: pool_flops_counter_hook, 281 | nn.MaxPool3d: pool_flops_counter_hook, 282 | nn.AvgPool3d: pool_flops_counter_hook, 283 | nn.AdaptiveMaxPool1d: pool_flops_counter_hook, 284 | nn.AdaptiveAvgPool1d: pool_flops_counter_hook, 285 | nn.AdaptiveMaxPool2d: pool_flops_counter_hook, 286 | nn.AdaptiveAvgPool2d: pool_flops_counter_hook, 287 | nn.AdaptiveMaxPool3d: pool_flops_counter_hook, 288 | nn.AdaptiveAvgPool3d: pool_flops_counter_hook, 289 | # BNs 290 | nn.BatchNorm1d: bn_flops_counter_hook, 291 | nn.BatchNorm2d: bn_flops_counter_hook, 292 | nn.BatchNorm3d: bn_flops_counter_hook, 293 | 294 | nn.InstanceNorm1d: bn_flops_counter_hook, 295 | nn.InstanceNorm2d: bn_flops_counter_hook, 296 | nn.InstanceNorm3d: bn_flops_counter_hook, 297 | nn.GroupNorm: bn_flops_counter_hook, 298 | nn.LayerNorm: bn_flops_counter_hook, 299 | # FC 300 | nn.Linear: linear_flops_counter_hook, 301 | # Upscale 302 | nn.Upsample: upsample_flops_counter_hook, 303 | # Deconvolution 304 | nn.ConvTranspose1d: conv_flops_counter_hook, 305 | nn.ConvTranspose2d: conv_flops_counter_hook, 306 | nn.ConvTranspose3d: conv_flops_counter_hook, 307 | # RNN 308 | nn.RNN: rnn_flops_counter_hook, 309 | nn.GRU: rnn_flops_counter_hook, 310 | nn.LSTM: rnn_flops_counter_hook, 311 | nn.RNNCell: rnn_cell_flops_counter_hook, 312 | nn.LSTMCell: rnn_cell_flops_counter_hook, 313 | nn.GRUCell: rnn_cell_flops_counter_hook, 314 | nn.MultiheadAttention: multihead_attention_counter_hook 315 | } 316 | 317 | if hasattr(nn, 'GELU'): 318 | MODULES_MAPPING[nn.GELU] = relu_flops_counter_hook 319 | 320 | try: 321 | import torchvision.ops as tops 322 | MODULES_MAPPING[tops.DeformConv2d] = deformable_conv_flops_counter_hook 323 | except ImportError: 324 | pass 325 | 326 | try: 327 | from timm.models.vision_transformer import Attention as timm_Attention 328 | MODULES_MAPPING[timm_Attention] = timm_attention_counter_hook 329 | except ImportError: 330 | pass 331 | 332 | 333 | def _linear_functional_flops_hook(input, weight, bias=None): 334 | out_features = weight.shape[0] 335 | macs = input.numel() * out_features 336 | if bias is not None: 337 | macs += out_features 338 | return macs 339 | 340 | 341 | def _numel_functional_flops_hook(input, *args, **kwargs): 342 | return input.numel() 343 | 344 | 345 | def _interpolate_functional_flops_hook(*args, **kwargs): 346 | input = kwargs.get('input', None) 347 | if input is None and len(args) > 0: 348 
|         input = args[0]
349 |
350 |     assert input.dim() - 2 > 0, "Input of interpolate should have NC... layout"
351 |
352 |     size = kwargs.get('size', None)
353 |     if size is None and len(args) > 1:
354 |         size = args[1]
355 |
356 |     if size is not None:
357 |         if isinstance(size, (tuple, list)):
358 |             return int(np.prod(size, dtype=np.int64)) * \
359 |                 np.prod(input.shape[:2], dtype=np.int64)
360 |         else:
361 |             return int(size) ** (input.dim() - 2) * \
362 |                 np.prod(input.shape[:2], dtype=np.int64)
363 |
364 |     scale_factor = kwargs.get('scale_factor', None)
365 |     if scale_factor is None and len(args) > 2:
366 |         scale_factor = args[2]
367 |     assert scale_factor is not None, ("either size or scale_factor "
368 |                                       "should be passed to interpolate")
369 |
370 |     flops = input.numel()
371 |     if isinstance(scale_factor, tuple) and len(scale_factor) == len(input.shape) - 2:
372 |         flops *= int(np.prod(scale_factor, dtype=np.int64))
373 |     else:  # NC... layout is assumed, see interpolate docs
374 |         flops *= scale_factor ** (input.dim() - 2)
375 |
376 |     return flops
377 |
378 |
379 | def _matmul_tensor_flops_hook(input, other, *args, **kwargs):
380 |     flops = np.prod(input.shape, dtype=np.int64) * other.shape[-1]
381 |     return flops
382 |
383 |
384 | def _addmm_tensor_flops_hook(input, mat1, mat2, *, beta=1, alpha=1, out=None):
385 |     flops = np.prod(mat1.shape, dtype=np.int64) * mat2.shape[-1]
386 |     if beta != 0:
387 |         flops += np.prod(input.shape, dtype=np.int64)
388 |     return flops
389 |
390 |
391 | def _elementwise_tensor_flops_hook(input, other, *args, **kwargs):
392 |     if not torch.is_tensor(input):
393 |         if torch.is_tensor(other):
394 |             return np.prod(other.shape, dtype=np.int64)
395 |         else:
396 |             return 1
397 |     elif not torch.is_tensor(other):
398 |         return np.prod(input.shape, dtype=np.int64)
399 |     else:
400 |         dim_input = len(input.shape)
401 |         dim_other = len(other.shape)
402 |         max_dim = max(dim_input, dim_other)
403 |
404 |         # Broadcasting aligns shapes from the trailing dimension, so the
405 |         # resulting shape has to be assembled from the right.
406 |         final_shape = []
407 |         for i in range(1, max_dim + 1):
408 |             in_i = input.shape[-i] if i <= dim_input else 1
409 |             ot_i = other.shape[-i] if i <= dim_other else 1
410 |             final_shape.append(max(in_i, ot_i))
411 |         flops = np.prod(final_shape, dtype=np.int64)
412 |         return flops
413 |
414 |
415 | FUNCTIONAL_MAPPING = {
416 |     F.linear: _linear_functional_flops_hook,
417 |     F.relu: _numel_functional_flops_hook,
418 |     F.prelu: _numel_functional_flops_hook,
419 |     F.elu: _numel_functional_flops_hook,
420 |     F.relu6: _numel_functional_flops_hook,
421 |     F.gelu: _numel_functional_flops_hook,
422 |
423 |     F.avg_pool1d: _numel_functional_flops_hook,
424 |     F.avg_pool2d: _numel_functional_flops_hook,
425 |     F.avg_pool3d: _numel_functional_flops_hook,
426 |     F.max_pool1d: _numel_functional_flops_hook,
427 |     F.max_pool2d: _numel_functional_flops_hook,
428 |     F.max_pool3d: _numel_functional_flops_hook,
429 |     F.adaptive_avg_pool1d: _numel_functional_flops_hook,
430 |     F.adaptive_avg_pool2d: _numel_functional_flops_hook,
431 |     F.adaptive_avg_pool3d: _numel_functional_flops_hook,
432 |     F.adaptive_max_pool1d: _numel_functional_flops_hook,
433 |     F.adaptive_max_pool2d: _numel_functional_flops_hook,
434 |     F.adaptive_max_pool3d: _numel_functional_flops_hook,
435 |
436 |     F.softmax: _numel_functional_flops_hook,
437 |
438 |     F.upsample: _interpolate_functional_flops_hook,
439 |     F.interpolate: _interpolate_functional_flops_hook,
440 | }
441 |
442 | if hasattr(F, "silu"):
443 |     FUNCTIONAL_MAPPING[F.silu] = _numel_functional_flops_hook
444 |
445 |
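446 | # The hooks below cover tensor-level ops and estimate MACs directly from
447 | # operand shapes: e.g. for (B, N, M) @ (M, K) operands,
448 | # _matmul_tensor_flops_hook returns B * N * M * K.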
449 | TENSOR_OPS_MAPPING = {
450 |     torch.matmul: _matmul_tensor_flops_hook,
451 |     torch.Tensor.matmul: _matmul_tensor_flops_hook,
452 |     torch.mm: _matmul_tensor_flops_hook,
453 |     torch.Tensor.mm: _matmul_tensor_flops_hook,
454 |     torch.bmm: _matmul_tensor_flops_hook,
455 |     torch.Tensor.bmm: _matmul_tensor_flops_hook,
456 |
457 |     torch.addmm: _addmm_tensor_flops_hook,
458 |     torch.baddbmm: _addmm_tensor_flops_hook,
459 |     torch.Tensor.addmm: _addmm_tensor_flops_hook,
460 |
461 |     torch.mul: _elementwise_tensor_flops_hook,
462 |     torch.Tensor.mul: _elementwise_tensor_flops_hook,
463 |     torch.add: _elementwise_tensor_flops_hook,
464 |     torch.Tensor.add: _elementwise_tensor_flops_hook,
465 | }
466 |
-------------------------------------------------------------------------------- /ptflops/utils.py: --------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2021-2023 Sovrasov V. - All Rights Reserved
3 |  * You may use, distribute and modify this code under the
4 |  * terms of the MIT license.
5 |  * You should have received a copy of the MIT license with
6 |  * this file. If not visit https://opensource.org/licenses/MIT
7 | '''
8 |
9 |
10 | from typing import Optional
11 |
12 |
13 | def flops_to_string(flops: int, units: Optional[str] = None, precision: int = 2) -> str:
14 |     """
15 |     Converts integer MACs representation to a readable string.
16 |
17 |     :param flops: Input MACs.
18 |     :param units: Units for string representation of MACs (GMac, MMac or KMac).
19 |     :param precision: Floating point precision for representing MACs in
20 |     given units.
21 |     """
22 |     if units is None:
23 |         if flops // 10**9 > 0:
24 |             return str(round(flops / 10.**9, precision)) + ' GMac'
25 |         elif flops // 10**6 > 0:
26 |             return str(round(flops / 10.**6, precision)) + ' MMac'
27 |         elif flops // 10**3 > 0:
28 |             return str(round(flops / 10.**3, precision)) + ' KMac'
29 |         else:
30 |             return str(flops) + ' Mac'
31 |     else:
32 |         if units == 'GMac':
33 |             return str(round(flops / 10.**9, precision)) + ' ' + units
34 |         elif units == 'MMac':
35 |             return str(round(flops / 10.**6, precision)) + ' ' + units
36 |         elif units == 'KMac':
37 |             return str(round(flops / 10.**3, precision)) + ' ' + units
38 |         else:
39 |             return str(flops) + ' Mac'
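40 |
41 | # flops_to_string examples (informal):
42 | #   flops_to_string(3_141_592)          -> '3.14 MMac'
43 | #   flops_to_string(1024, units='KMac') -> '1.02 KMac'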
51 | """ 52 | if units is None: 53 | if params_num // 10 ** 6 > 0: 54 | return str(round(params_num / 10 ** 6, precision)) + ' M' 55 | elif params_num // 10 ** 3: 56 | return str(round(params_num / 10 ** 3, precision)) + ' k' 57 | else: 58 | return str(params_num) 59 | else: 60 | if units == 'M': 61 | return str(round(params_num / 10.**6, precision)) + ' ' + units 62 | elif units == 'K': 63 | return str(round(params_num / 10.**3, precision)) + ' ' + units 64 | elif units == 'B': 65 | return str(round(params_num / 10.**9, precision)) + ' ' + units 66 | else: 67 | return str(params_num) 68 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ptflops" 7 | version = "0.7.5" 8 | dependencies = [ 9 | "torch>=2.0", 10 | ] 11 | requires-python = ">=3.9" 12 | authors = [ 13 | {name = "Vladislav Sovrasov", email = "sovrasov.vlad@gmail.com"}, 14 | ] 15 | maintainers = [ 16 | {name = "Vladislav Sovrasov", email = "sovrasov.vlad@gmail.com"}, 17 | ] 18 | description = "Flops counter for neural networks in pytorch framework" 19 | readme = "README.md" 20 | license = {file = "LICENSE"} 21 | keywords = ["pytorch", "cnn", "transformer"] 22 | classifiers = [ 23 | "License :: OSI Approved :: MIT License", 24 | "Programming Language :: Python :: 3.9" 25 | ] 26 | 27 | [project.optional-dependencies] 28 | dev = [ 29 | "flake8==3.8.1", 30 | "flake8-import-order==0.18.1", 31 | "isort==4.3.21", 32 | "torchvision>=0.5.0", 33 | "pytest==7.1.2", 34 | "packaging", 35 | ] 36 | 37 | [project.urls] 38 | Homepage = "https://github.com/sovrasov/flops-counter.pytorch/" 39 | Documentation = "https://github.com/sovrasov/flops-counter.pytorch/blob/master/README.md" 40 | Repository = "https://github.com/sovrasov/flops-counter.pytorch.git" 41 | "Bug Tracker" = "https://github.com/sovrasov/flops-counter.pytorch/issues" 42 | Changelog = "https://github.com/sovrasov/flops-counter.pytorch/blob/master/CHANGELOG.md" 43 | 44 | [tool.setuptools.packages.find] 45 | include = ["ptflops"] -------------------------------------------------------------------------------- /samples/bert.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from transformers import BertForSequenceClassification, BertTokenizer 5 | 6 | from ptflops import get_model_complexity_info 7 | 8 | 9 | def bert_input_constructor(input_shape, tokenizer): 10 | inp_seq = "" 11 | for _ in range(input_shape[1] - 2): # there are two special tokens [CLS] and [SEP] 12 | inp_seq += tokenizer.pad_token # let's use pad token to form a fake 13 | # sequence for subsequent flops calculation 14 | 15 | inputs = tokenizer([inp_seq] * input_shape[0], padding=True, truncation=True, 16 | return_tensors="pt") 17 | labels = torch.tensor([1] * input_shape[0]) 18 | # Batch size input_shape[0], sequence length input_shape[128] 19 | inputs = dict(inputs) 20 | inputs.update({"labels": labels}) 21 | return inputs 22 | 23 | 24 | if __name__ == '__main__': 25 | bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 26 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased') 27 | flops_count, params_count = get_model_complexity_info( 28 | model, (2, 128), as_strings=True, 29 | input_constructor=partial(bert_input_constructor, 
26 |
27 |
28 | if __name__ == '__main__':
29 |     bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
30 |     model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
31 |     flops_count, params_count = get_model_complexity_info(
32 |         model, (2, 128), as_strings=True,
33 |         input_constructor=partial(bert_input_constructor, tokenizer=bert_tokenizer),
34 |         print_per_layer_stat=False)
35 |     print('{:<30} {:<8}'.format('Computational complexity: ', flops_count))
36 |     print('{:<30} {:<8}'.format('Number of parameters: ', params_count))
37 |
38 |     # Output:
39 |     # Computational complexity:      21.74 GMac
40 |     # Number of parameters:          109.48 M
41 |
-------------------------------------------------------------------------------- /samples/classification.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 |
4 | import torch
5 | import torchvision
6 | from packaging import version
7 | from torchvision import models as models
8 |
9 | from ptflops import get_model_complexity_info
10 |
11 | pt_models = {'resnet18': models.resnet18,
12 |              'resnet50': models.resnet50,
13 |              'alexnet': models.alexnet,
14 |              'vgg16': models.vgg16,
15 |              'squeezenet': models.squeezenet1_0,
16 |              'densenet': models.densenet161,
17 |              'inception': models.inception_v3,
18 |              'convnext_base': models.convnext_base}
19 |
20 | # vit_b_16 is registered only when the installed torchvision provides it
21 | if version.parse(torchvision.__version__) > version.parse('0.15'):
22 |     pt_models['vit_b_16'] = models.vit_b_16
23 |
24 |
25 | if __name__ == '__main__':
26 |     parser = argparse.ArgumentParser(description='ptflops sample script')
27 |     parser.add_argument('--device', type=int, default=0,
28 |                         help='Device to store the model.')
29 |     parser.add_argument('--model', choices=list(pt_models.keys()),
30 |                         type=str, default='resnet18')
31 |     parser.add_argument('--backend', choices=['pytorch', 'aten'],
32 |                         type=str, default='pytorch')
33 |     parser.add_argument('--result', type=str, default=None)
34 |     args = parser.parse_args()
35 |
36 |     if args.result is None:
37 |         ost = sys.stdout
38 |     else:
39 |         ost = open(args.result, 'w')
40 |
41 |     net = pt_models[args.model]()
42 |
43 |     if torch.cuda.is_available():
44 |         net.cuda(device=args.device)
45 |
46 |     if args.model == 'inception':
47 |         input_res = (3, 299, 299)
48 |     else:
49 |         input_res = (3, 224, 224)
50 |
51 |     macs, params = get_model_complexity_info(net, input_res,
52 |                                              as_strings=True,
53 |                                              backend=args.backend,
54 |                                              print_per_layer_stat=True,
55 |                                              ost=ost)
56 |     print('{:<30} {:<8}'.format('Computational complexity: ', macs))
57 |     print('{:<30} {:<8}'.format('Number of parameters: ', params))
58 |
-------------------------------------------------------------------------------- /tests/common_test.py: --------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ptflops import get_model_complexity_info
6 | from ptflops.flops_counter import FLOPS_BACKEND
7 |
8 |
9 | class TestOperations:
10 |     @pytest.fixture
11 |     def default_input_image_size(self):
12 |         return (3, 224, 224)
13 |
14 |     @pytest.fixture
15 |     def simple_model_mm(self):
16 |         class CustomModel(nn.Module):
17 |             def __init__(self):
18 |                 super().__init__()
19 |
20 |             def forward(self, x):
21 |                 return x.matmul(x.t())
22 |
23 |         return CustomModel()
24 |
25 |     @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
26 |     def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
27 |         net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
28 |         macs, params = get_model_complexity_info(net, default_input_image_size,
29 |                                                  as_strings=False,
30 |                                                  print_per_layer_stat=False,
31 |                                                  backend=backend)
32 |
33 |         assert params == 3 * 3 * 2 * 3 + 2
34 |         assert macs == 2759904
35 |
36 |     @pytest.mark.parametrize("backend",
[FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN]) 37 | def test_fc(self, backend: FLOPS_BACKEND): 38 | net = nn.Sequential(nn.Linear(3, 2, bias=True)) 39 | macs, params = get_model_complexity_info(net, (3,), 40 | as_strings=False, 41 | print_per_layer_stat=False, 42 | backend=backend) 43 | 44 | assert params == 3 * 2 + 2 45 | assert macs == 8 46 | 47 | @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN]) 48 | def test_fc_multidim(self, backend: FLOPS_BACKEND): 49 | net = nn.Sequential(nn.Linear(3, 2, bias=False)) 50 | macs, params = get_model_complexity_info(net, (4, 5, 3), 51 | as_strings=False, 52 | print_per_layer_stat=False, 53 | backend=backend) 54 | 55 | assert params == 3 * 2 56 | assert macs == (3 * 2) * 4 * 5 57 | 58 | def test_input_constructor_tensor(self): 59 | net = nn.Sequential(nn.Linear(3, 2, bias=True)) 60 | 61 | def input_constructor(input_res): 62 | return torch.ones(()).new_empty((1, *input_res)) 63 | 64 | macs, params = get_model_complexity_info(net, (3,), 65 | input_constructor=input_constructor, 66 | as_strings=False, 67 | print_per_layer_stat=False, 68 | backend=FLOPS_BACKEND.PYTORCH) 69 | 70 | assert (macs, params) == (8, 8) 71 | 72 | def test_input_constructor_dict(self): 73 | class CustomLinear(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | self.linear = nn.Linear(3, 2, bias=True) 77 | 78 | def forward(self, x): 79 | return self.linear(x) 80 | 81 | def input_constructor(input_res): 82 | return dict(x=torch.ones(()).new_empty((1, *input_res))) 83 | 84 | macs, params = \ 85 | get_model_complexity_info(CustomLinear(), (3,), 86 | input_constructor=input_constructor, 87 | as_strings=False, 88 | print_per_layer_stat=False, 89 | backend=FLOPS_BACKEND.PYTORCH) 90 | 91 | assert (macs, params) == (8, 8) 92 | 93 | @pytest.mark.parametrize("out_size", [(20, 20), 20]) 94 | def test_func_interpolate_args(self, out_size): 95 | class CustomModel(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def forward(self, x): 100 | return nn.functional.interpolate(input=x, size=out_size, 101 | mode='bilinear', align_corners=False) 102 | 103 | macs, params = \ 104 | get_model_complexity_info(CustomModel(), (3, 10, 10), 105 | as_strings=False, 106 | print_per_layer_stat=False, 107 | backend=FLOPS_BACKEND.PYTORCH) 108 | 109 | assert params == 0 110 | assert macs == 1200 111 | 112 | CustomModel.forward = lambda self, x: nn.functional.interpolate(x, out_size, 113 | mode='bilinear') 114 | 115 | macs, params = \ 116 | get_model_complexity_info(CustomModel(), (3, 10, 10), 117 | as_strings=False, 118 | print_per_layer_stat=False, 119 | backend=FLOPS_BACKEND.PYTORCH) 120 | assert params == 0 121 | assert macs == 1200 122 | 123 | CustomModel.forward = lambda self, x: nn.functional.interpolate(x, scale_factor=2, 124 | mode='bilinear') 125 | 126 | macs, params = \ 127 | get_model_complexity_info(CustomModel(), (3, 10, 10), 128 | as_strings=False, 129 | print_per_layer_stat=False, 130 | backend=FLOPS_BACKEND.PYTORCH) 131 | assert params == 0 132 | assert macs == 1200 133 | 134 | def test_ten_matmul(self, simple_model_mm): 135 | macs, params = \ 136 | get_model_complexity_info(simple_model_mm, (10, ), 137 | as_strings=False, 138 | print_per_layer_stat=False, 139 | backend=FLOPS_BACKEND.PYTORCH) 140 | 141 | assert params == 0 142 | assert macs > 0 143 | 144 | def test_aten_ignore(self, simple_model_mm): 145 | ignored_list = [torch.ops.aten.matmul, torch.ops.aten.mm] 146 | macs, params = \ 147 | get_model_complexity_info(simple_model_mm, (10, 
), backend=FLOPS_BACKEND.ATEN, 148 | as_strings=False, 149 | print_per_layer_stat=False, 150 | ignore_modules=ignored_list) 151 | 152 | assert params == 0 153 | assert macs == 0 154 | 155 | def test_aten_custom(self, simple_model_mm): 156 | reference = 42 157 | custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference} 158 | 159 | macs, params = \ 160 | get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN, 161 | as_strings=False, 162 | print_per_layer_stat=False, 163 | custom_modules_hooks=custom_hooks) 164 | 165 | assert params == 0 166 | assert macs == reference 167 | 168 | def test_torch_ignore_func(self, simple_model_mm): 169 | macs, params = \ 170 | get_model_complexity_info(simple_model_mm, (10, ), 171 | backend=FLOPS_BACKEND.PYTORCH, 172 | as_strings=False, 173 | print_per_layer_stat=False, 174 | backend_specific_config={'count_functional': False}) 175 | 176 | assert params == 0 177 | assert macs == 0 178 | --------------------------------------------------------------------------------
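A minimal sketch of how a user-supplied hook plugs into the machinery above (not a file from this repository: SineActivation and sine_flops_hook are hypothetical names, the one-operation-per-output-element cost is an assumption, and custom_modules_hooks is assumed to be honored by the pytorch backend the same way tests/common_test.py exercises it for the aten backend):

import torch
import torch.nn as nn

from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND


class SineActivation(nn.Module):  # hypothetical layer with no built-in hook
    def forward(self, x):
        return torch.sin(x)


def sine_flops_hook(module, input, output):
    # Assumed cost model: one operation per output element. A hook reports
    # its cost by accumulating into module.__flops__, as the built-in hooks
    # in ptflops/pytorch_ops.py do.
    module.__flops__ += int(output.numel())


net = nn.Sequential(nn.Conv2d(3, 8, 3), SineActivation())
macs, params = get_model_complexity_info(
    net, (3, 224, 224), as_strings=True, print_per_layer_stat=False,
    backend=FLOPS_BACKEND.PYTORCH,
    custom_modules_hooks={SineActivation: sine_flops_hook})
print(macs, params)

Since neither MODULES_MAPPING nor TENSOR_OPS_MAPPING covers torch.sin, the activation would contribute nothing to the reported total without the custom hook.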