├── .gitignore ├── LICENSE ├── README.md ├── setup.py └── tflite_flops ├── __init__.py ├── __main__.py └── calc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 oVo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tflite-flops 2 | Roughly calculate FLOPs (floating-point operations) of a tflite format model. 3 | 4 | ### Install 5 | ``` 6 | pip3 install git+https://github.com/lisosia/tflite-flops 7 | ``` 8 | 9 | ### Usage 10 | ``` 11 | python3 -m tflite_flops example.tflite 12 | ``` 13 | 14 | ### Example 15 | ``` 16 | wget https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz 17 | tar xvf mobilenet_v2_1.0_224.tgz 18 | python3 -m tflite_flops ./mobilenet_v2_1.0_224.tflite 19 | ``` 20 | The following lines are printed: 21 | ``` 22 | OP_NAME | M FLOPS 23 | ------------------------------ 24 | CONV_2D | 21.7 25 | DEPTHWISE_CONV_2D | 7.2 26 | CONV_2D | 12.8 27 | . 28 | . 29 | . 30 | CONV_2D | 2.6 31 | RESHAPE | 32 | SOFTMAX | 33 | ------------------------------ 34 | Total: 601.6 M FLOPS 35 | ``` 36 | 37 | ### How is it calculated? 38 | 39 | For a Conv layer: 40 | ``` 41 | Multiply-Accumulate (MAC) = output_h * output_w * output_c * kernel_h * kernel_w * input_c 42 | (= output_h * output_w * weight_size) 43 | Floating-point operations (FLOPs) = 2 * MAC 44 | ``` 45 | 46 | ### Limitation 47 | 48 | Only Conv and DepthwiseConv layers are counted for now. This is usually sufficient, because convolutions dominate the compute of most CNN models. 
49 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='tflite-flops', 5 | version='0.0.0', 6 | description='roughly calculate FLOPS of tflite format model', 7 | install_requires=['tflite'], 8 | packages=['tflite_flops'], 9 | python_requires='>=3.5' 10 | ) 11 | -------------------------------------------------------------------------------- /tflite_flops/__init__.py: -------------------------------------------------------------------------------- 1 | from tflite_flops.calc import calc_flops 2 | -------------------------------------------------------------------------------- /tflite_flops/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import tflite_flops 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('input_model', help='Input TFLite model') 8 | 9 | args = parser.parse_args() 10 | 11 | tflite_flops.calc_flops(args.input_model) 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /tflite_flops/calc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """calculate flops of tflite model. 
only conv and depthwise_conv considered 3 | depends on https://pypi.org/project/tflite/ 4 | reference code 5 | https://github.com/jackwish/tflite/blob/master/tests/test_mobilenet.py 6 | """ 7 | 8 | import sys 9 | import tflite 10 | 11 | 12 | def calc_flops(path): 13 | with open(path, 'rb') as f: 14 | buf = f.read() 15 | model = tflite.Model.GetRootAsModel(buf, 0) 16 | 17 | graph = model.Subgraphs(0) 18 | 19 | # help(tflite.BuiltinOperator) 20 | # ABS = 101 21 | # CONV_2D = 3 22 | # CUMSUM = 128 23 | 24 | # print funcs 25 | _dict_builtin_op_code_to_name = {v: k for k, v in tflite.BuiltinOperator.__dict__.items() if type(v) == int} 26 | def print_header(): 27 | print("%-18s | M FLOPS" % ("OP_NAME")) 28 | print("------------------------------") 29 | def print_flops(op_code_builtin, flops): 30 | print("%-18s | %.1f" % (_dict_builtin_op_code_to_name[op_code_builtin], flops / 1.0e6)) 31 | def print_none(op_code_builtin): 32 | print("%-18s | " % (_dict_builtin_op_code_to_name[op_code_builtin])) 33 | def print_footer(total_flops): 34 | print("------------------------------") 35 | print("Total: %.1f M FLOPS" % (total_flops / 1.0e6)) 36 | 37 | total_flops = 0.0 38 | print_header() 39 | for i in range(graph.OperatorsLength()): 40 | op = graph.Operators(i) 41 | op_code = model.OperatorCodes(op.OpcodeIndex()) 42 | op_code_builtin = op_code.BuiltinCode() 43 | 44 | op_opt = op.BuiltinOptions() 45 | 46 | flops = 0.0 47 | if op_code_builtin == tflite.BuiltinOperator.CONV_2D: 48 | # input shapes: in, weight, bias 49 | in_shape = graph.Tensors( op.Inputs(0) ).ShapeAsNumpy() 50 | filter_shape = graph.Tensors( op.Inputs(1) ).ShapeAsNumpy() 51 | bias_shape = graph.Tensors( op.Inputs(2) ).ShapeAsNumpy() 52 | # output shape 53 | out_shape = graph.Tensors( op.Outputs(0) ).ShapeAsNumpy() 54 | # ops options 55 | opt = tflite.Conv2DOptions() 56 | opt.Init(op_opt.Bytes, op_opt.Pos) 57 | # opt.StrideH() 58 | 59 | # flops. 2x means mul(1)+add(1). 
2x not needed if you calculate MACCs 60 | # refer to https://github.com/AlexeyAB/darknet/src/convolutional_layer.c `l.blopfs =` 61 | flops = 2 * out_shape[1] * out_shape[2] * filter_shape[0] * filter_shape[1] * filter_shape[2] * filter_shape[3] 62 | print_flops(op_code_builtin, flops) 63 | 64 | elif op_code_builtin == tflite.BuiltinOperator.DEPTHWISE_CONV_2D: 65 | in_shape = graph.Tensors( op.Inputs(0) ).ShapeAsNumpy() 66 | filter_shape = graph.Tensors( op.Inputs(1) ).ShapeAsNumpy() 67 | out_shape = graph.Tensors( op.Outputs(0) ).ShapeAsNumpy() 68 | # flops 69 | flops = 2 * out_shape[1] * out_shape[2] * filter_shape[0] * filter_shape[1] * filter_shape[2] * filter_shape[3] 70 | print_flops(op_code_builtin, flops) 71 | 72 | else: 73 | print_none(op_code_builtin) 74 | 75 | total_flops += flops 76 | print_footer(total_flops) 77 | 78 | if __name__ == "__main__": 79 | path = sys.argv[1] 80 | calc_flops(path) 81 | 82 | ########################################################################## 83 | # darknet 84 | # maxpool_layer.c 85 | # l.bflops = (l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.; 86 | # convolutional_layer.c 87 | # l.bflops = (2.0 * l.nweights * l.out_h*l.out_w) / 1000000000.; 88 | # shortcut_layer.c 89 | # l.bflops = l.out_w * l.out_h * l.out_c * l.n / 1000000000.; 90 | 91 | # netron (not refered) 92 | # https://github.com/lutzroeder/netron/tree/main/source 93 | --------------------------------------------------------------------------------