├── .gitignore ├── LICENCE ├── README.md ├── compare1.py ├── compare2.py ├── example1.py ├── example2.py ├── requirements.txt ├── screen_shot.png ├── setup.py └── torch_flops ├── __init__.py ├── flops_engine.py └── flops_ops.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | build/ 4 | dist/ 5 | *.egg-info/ 6 | TODO.md 7 | temp/ 8 | package_*.sh -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yue Lu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torch_flops 2 | 3 | ## Introduction 4 | [torch_flops中文介绍 - 知乎](https://zhuanlan.zhihu.com/p/663566912) 5 | 6 | This is a library for calculating FLOPs of pytorch models. Compared with other libraries such as [`thop`](https://github.com/Lyken17/pytorch-OpCounter), [`ptflops`](https://github.com/sovrasov/flops-counter.pytorch), [`torchinfo`](https://github.com/TylerYep/torchinfo) and [`torchanalyse`](https://github.com/HaoKang-Timmy/torchanalyse), the **advantage of this library** is that it can capture **all calculation operations** in the `forward` process, **not limited to only the subclasses of** `nn.Module`. 7 | 8 | **Update Note**: Introducing support for displaying the **execution time** of each operation. Please use `flops_counter.print_result_table()` to see the detailed results. 9 | 10 | **Update Note**: Introducing support for displaying the **GPU memory usage** of each operation. In the result table, `mem_before_op`, `mem_after_op` represent the memories (counted using `torch.cuda.max_memory_allocated()` `(default)` or `torch.cuda.memory_allocated()`) before and after the operation. `mem_delta` represent the difference between `mem_after_op` and `mem_before_op`. Please note that just run one model each time in a program in order to obtain accurate memory statistics. 11 | 12 | 13 | ## Usage 14 | ### Installation 15 | ``` 16 | pip install torch_flops -i https://pypi.org/simple 17 | ``` 18 | 19 | ### Requirements 20 | 21 | + python >= 3.10 (for new python features) 22 | + pytorch >= 2.0 (for `torch.fx` support) 23 | + tabulate (for printing the summary of operations) 24 | 25 | ### Example 1 26 | An expamle for calculating the FLOPs of ViT-base16 and ResNet-50 is given in [`example1.py`](example1.py). The example requires the [`timm`](https://github.com/huggingface/pytorch-image-models) library. You can calculate the FLOPs in three lines: 27 | ```python 28 | # NOTE: First run the model once for accurate time measurement in the following process. 29 | # The input `x` and the model should be placed on GPU for memory measurement. 30 | with torch.no_grad(): 31 | model(x) 32 | # Initialize the `TorchFLOPsByFX`. Please read the doc of the class for initialization options. 33 | flops_counter = TorchFLOPsByFX(model) 34 | # Feed the input tensor to the model 35 | flops_counter.propagate(x) 36 | # Print the full result table. It also returns the detailed result of each operation in a 2D list. 37 | result_table = flops_counter.print_result_table() 38 | # Print FLOPs, execution time and max GPU memory. 39 | total_flops = flops_counter.print_total_flops(show=True) 40 | total_time = flops_counter.print_total_time() 41 | max_memory = flops_counter.print_max_memory() 42 | ``` 43 | The output of `example1.py` is: 44 | ``` 45 | ========== vit_base16 ========== 46 | total_flops = 35,164,979,282 47 | total_time = 14.015 ms 48 | max_memory = 362,289,152 Bytes 49 | ========== resnet50 ========== 50 | total_flops = 8,227,340,288 51 | total_time = 10.867 ms 52 | max_memory = 249,894,400 Bytes 53 | ``` 54 | ![image](./screen_shot.png) 55 | 56 | ### Example 2 57 | Another example of calculating the FLOPs for an attention block is provided in [`example2.py`](example2.py). However, You can define a simple model to check the result (see [`compare.py`](compare.py)). 58 | 59 | ```python 60 | C = 768 61 | device = 'cuda:0' 62 | 63 | # Define the model: an attention block (refer to "timm": https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) 64 | block = Block(C, num_heads=2, qkv_bias=True) 65 | block.attn.fused_attn = False 66 | block.eval() 67 | model = block 68 | 69 | # Input 70 | # N: number of tokens 71 | N = 14**2 + 1 72 | B = 1 73 | x = torch.randn([B, N, C]).to(device) 74 | model.to(device) 75 | 76 | # NOTE: First run the model once for accurate time measurement in the following process. 77 | with torch.no_grad(): 78 | model(x) 79 | 80 | # Output 81 | # Build the graph of the model. You can specify the operations (listed in `MODULE_FLOPs_MAPPING`, `FUNCTION_FLOPs_MAPPING` and `METHOD_FLOPs_MAPPING` in 'flops_ops.py') to ignore. 82 | flops_counter = TorchFLOPsByFX(model) 83 | # Print the grath (not essential) 84 | print('*' * 120) 85 | flops_counter.graph_model.graph.print_tabular() 86 | # Feed the input tensor 87 | with torch.no_grad(): 88 | flops_counter.propagate(x) 89 | # Print the flops of each node in the graph. Note that if there are unsupported operations, the "flops" of these ops will be marked as 'not recognized'. 90 | print('*' * 120) 91 | result_table = flops_counter.print_result_table() 92 | # Print the total FLOPs 93 | total_flops = flops_counter.print_total_flops() 94 | total_time = flops_counter.print_total_time() 95 | max_memory = flops_counter.print_max_memory() 96 | ``` 97 | You can also feed more than one sequential arguments for the model in `propagate()` if the `model.forward()` function need not only one arguments. 98 | 99 | # Advantage 100 | `torch_flops` can capture all the operations excuted in the forward including the operations not wrapped by `nn.Module`, like `torch.matmul`, `@`, `+` and `tensor.exp`, and it can ignore the FLOPs of the modules not used in the forward process. 101 | 102 | There is a comparison of `torch_flops` (this repo), `torchanalyse`, `thop` and `ptflops` in the script [`compare.py`](compare.py). 103 | The output of 104 | 105 | `python compare.py`: 106 | 107 | ``` 108 | **************************************** Model **************************************** 109 | SimpleModel( 110 | (layer): Linear(in_features=5, out_features=4, bias=True) 111 | ) 112 | tensor([[-0.2077, 0.2623, 1.3978, -0.4170]], grad_fn=) 113 | ================================================================================ 114 | **************************************** torch_flops **************************************** 115 | =========== =========== =========== ===================== ======= 116 | node_name node_op op_target nn_module_stack[-1] flops 117 | =========== =========== =========== ===================== ======= 118 | x placeholder x 0 119 | layer call_module layer Linear 40 120 | output output output 0 121 | =========== =========== =========== ===================== ======= 122 | torch_flops: 40 FLOPs 123 | ================================================================================ 124 | **************************************** torchanalyse **************************************** 125 | torchanalyse: 40 FLOPs 126 | ================================================================================ 127 | **************************************** thop **************************************** 128 | [INFO] Register count_linear() for . 129 | thop: 20 MACs 130 | ================================================================================ 131 | **************************************** ptflops **************************************** 132 | Warning: module SimpleModel is treated as a zero-op. 133 | SimpleModel( 134 | 24, 100.000% Params, 24.0 Mac, 100.000% MACs, 135 | (layer): Linear(24, 100.000% Params, 24.0 Mac, 100.000% MACs, in_features=5, out_features=4, bias=True) 136 | ) 137 | ptflops: 24 MACs 138 | ================================================================================ 139 | ``` 140 | 141 | Now let's add an operation `x += 1.` in `forward()`. 142 | The output of 143 | 144 | `python compare.py --add_one`: 145 | 146 | ``` 147 | **************************************** Model **************************************** 148 | SimpleModel( 149 | (layer): Linear(in_features=5, out_features=4, bias=True) 150 | ) 151 | tensor([[1.0426, 0.6963, 1.7114, 1.6526]], grad_fn=) 152 | ================================================================================ 153 | **************************************** torch_flops **************************************** 154 | =========== ============= ======================= ===================== ======= 155 | node_name node_op op_target nn_module_stack[-1] flops 156 | =========== ============= ======================= ===================== ======= 157 | x placeholder x 0 158 | layer call_module layer Linear 40 159 | add call_function 4 160 | output output output 0 161 | =========== ============= ======================= ===================== ======= 162 | torch_flops: 44 FLOPs 163 | ================================================================================ 164 | **************************************** torchanalyse **************************************** 165 | torchanalyse: 40 FLOPs 166 | ================================================================================ 167 | **************************************** thop **************************************** 168 | [INFO] Register count_linear() for . 169 | thop: 20 MACs 170 | ================================================================================ 171 | **************************************** ptflops **************************************** 172 | Warning: module SimpleModel is treated as a zero-op. 173 | SimpleModel( 174 | 24, 100.000% Params, 24.0 Mac, 100.000% MACs, 175 | (layer): Linear(24, 100.000% Params, 24.0 Mac, 100.000% MACs, in_features=5, out_features=4, bias=True) 176 | ) 177 | ptflops: 24 MACs 178 | ================================================================================ 179 | ``` 180 | 181 | **It can be seen that only `torch_flops` can capture the FLOPs of `x+=1`!** 182 | 183 | `torchinfo` is not compared here but it does not have this ability. We also find that some of the other libraries cannot calculate the FLOPs of the `bias` item in `nn.Linear` using `python compare.py --linear_no_bias`. 184 | 185 | 186 | # Supported Operations 187 | The supported operations are listed in the following (the keys of the dicts), which can also be seen in [`flops_ops.py`](torch_flops/flops_ops.py). 188 | Note that in addtion to the modules inherited from `nn.Module` (e.g. `nn.Linear`), the function (e.g. `@`, `+`, `torch.softmax`) and method operations (e.g. `tensor.softmax`) are also supported! 189 | 190 | ```python 191 | MODULE_FLOPs_MAPPING = { 192 | 'Linear': ModuleFLOPs_Linear, 193 | 'Identity': ModuleFLOPs_zero, 194 | 'Conv1d': ModuleFLOPs_ConvNd, 195 | 'Conv2d': ModuleFLOPs_ConvNd, 196 | 'Conv3d': ModuleFLOPs_ConvNd, 197 | 'AvgPool1d': ModuleFLOPs_AvgPoolNd, 198 | 'AvgPool2d': ModuleFLOPs_AvgPoolNd, 199 | 'AvgPool3d': ModuleFLOPs_AvgPoolNd, 200 | 'AdaptiveAvgPool1d': ModuleFLOPs_AdaptiveAvgPoolNd, 201 | 'AdaptiveAvgPool2d': ModuleFLOPs_AdaptiveAvgPoolNd, 202 | 'AdaptiveAvgPool3d': ModuleFLOPs_AdaptiveAvgPoolNd, 203 | 'MaxPool1d': ModuleFLOPs_MaxPoolNd, 204 | 'MaxPool2d': ModuleFLOPs_MaxPoolNd, 205 | 'MaxPool3d': ModuleFLOPs_MaxPoolNd, 206 | 'AdaptiveMaxPool1d': ModuleFLOPs_AdaptiveMaxPoolNd, 207 | 'AdaptiveMaxPool2d': ModuleFLOPs_AdaptiveMaxPoolNd, 208 | 'AdaptiveMaxPool3d': ModuleFLOPs_AdaptiveMaxPoolNd, 209 | 'LayerNorm': ModuleFLOPs_Norm, 210 | 'BatchNorm1d': ModuleFLOPs_Norm, 211 | 'BatchNorm2d': ModuleFLOPs_Norm, 212 | 'BatchNorm3d': ModuleFLOPs_Norm, 213 | 'InstanceNorm1d': ModuleFLOPs_Norm, 214 | 'InstanceNorm2d': ModuleFLOPs_Norm, 215 | 'InstanceNorm3d': ModuleFLOPs_Norm, 216 | 'GroupNorm': ModuleFLOPs_Norm, 217 | 'Dropout': ModuleFLOPs_zero, 218 | 'GELU': ModuleFLOPs_GELU, 219 | 'ReLU': ModuleFLOPs_elemwise, 220 | 'Flatten': ModuleFLOPs_zero, 221 | } 222 | FUNCTION_FLOPs_MAPPING = { 223 | 'getattr': FunctionFLOPs_zero, 224 | 'getitem': FunctionFLOPs_zero, 225 | 'mul': FunctionFLOPs_elemwise, 226 | 'truediv': FunctionFLOPs_elemwise, 227 | 'sub': FunctionFLOPs_elemwise, 228 | 'matmul': FunctionFLOPs_matmul, 229 | 'add': FunctionFLOPs_elemwise, 230 | 'concat': FunctionFLOPs_zero, 231 | '_assert': FunctionFLOPs_zero, 232 | 'eq': FunctionFLOPs_elemwise, 233 | 'cat': FunctionFLOPs_zero, 234 | 'linear': FunctionFLOPs_linear, 235 | } 236 | METHOD_FLOPs_MAPPING = { 237 | 'reshape': MethodFLOPs_zero, 238 | 'permute': MethodFLOPs_zero, 239 | 'unbind': MethodFLOPs_zero, 240 | 'transpose': MethodFLOPs_zero, 241 | 'repeat': MethodFLOPs_zero, 242 | 'unsqueeze': MethodFLOPs_zero, 243 | 'exp': MethodFLOPs_elemwise, 244 | 'sum': MethodFLOPs_sum, 245 | 'div': MethodFLOPs_elemwise, 246 | 'softmax': MethodFLOPs_softmax, 247 | 'expand': MethodFLOPs_zero, 248 | 'flatten': MethodFLOPs_zero, 249 | } 250 | ``` 251 | However, not all the operations in pytorch have been considered since it spends a lot of effort. If you need to add support for a certain operation, please raise an issue. You are also welcome to add more features to this repository. 252 | 253 | # Limitations 254 | `torch.fx` can capture all the operations in the forward process, but it requires a high version of pytorch. However, we recommod you to use the newer version of pytorch (>=2.0) to try the new features. 255 | 256 | When using `torch.fx`, the model should be able to successfully transformed into a [`graph_model`](torch_flops/flops_engine.py#L317) by [`symbolic_trace()`](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace). Dynamic control flow is not supported in the `forward` function. Please refer to https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing for more information. 257 | 258 | There are many operations not implemented so far. However, you can raise an issue or contact me (zgxd@mail.nwpu.edu.cn) to add new operations. 259 | 260 | # Acknowledgements 261 | 262 | `pytorch`: https://github.com/pytorch/pytorch 263 | 264 | `timm`: https://github.com/huggingface/pytorch-image-models 265 | 266 | `torchscan`: https://frgfm.github.io/torch-scan/index.html 267 | 268 | `torchprofile`: https://github.com/zhijian-liu/torchprofile 269 | -------------------------------------------------------------------------------- /compare1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from torch_flops import TorchFLOPsByFX 5 | import torchanalyse 6 | from thop import profile 7 | from ptflops import get_model_complexity_info 8 | 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--linear_no_bias", action='store_true') 13 | parser.add_argument("--add_one", action='store_true') 14 | inp_args = parser.parse_args() 15 | 16 | 17 | class SimpleModel(nn.Module): 18 | def __init__(self, args) -> None: 19 | super().__init__() 20 | self.layer = nn.Linear(5, 4, bias=not args.linear_no_bias) 21 | self.__add_one = args.add_one 22 | 23 | def forward(self, x: Tensor): 24 | x = self.layer(x) 25 | if self.__add_one: 26 | x += 1. 27 | return x 28 | 29 | 30 | if __name__ == "__main__": 31 | model = SimpleModel(inp_args).cuda() 32 | x = torch.randn(1, 5).cuda() 33 | y = model(x) 34 | print("*" * 40 + " Model " + "*" * 40) 35 | print(model) 36 | print(y) 37 | print("=" * 80) 38 | 39 | # ========= 40 | print("*" * 40 + " torch_flops " + "*" * 40) 41 | flops_counter = TorchFLOPsByFX(model) 42 | # flops_counter.graph_model.graph.print_tabular() 43 | flops_counter.propagate(x) 44 | flops_counter.print_result_table() 45 | flops_1 = flops_counter.print_total_flops(show=False) 46 | print(f"torch_flops: {flops_1} FLOPs") 47 | print("=" * 80) 48 | 49 | # ========= 50 | print("*" * 40 + " torchanalyse " + "*" * 40) 51 | unit = torchanalyse.Unit(unit_flop='mFLOP') 52 | system = torchanalyse.System( 53 | unit, 54 | frequency=940, 55 | flops=123, 56 | onchip_mem_bw=900, 57 | pe_min_density_support=0.0001, 58 | accelerator_type="structured", 59 | model_on_chip_mem_implications=False, 60 | on_chip_mem_size=32, 61 | ) 62 | result_2 = torchanalyse.profiler(model, x, system, unit) 63 | flops_2 = sum(result_2['Flops (mFLOP)'].values) / 1e3 64 | print(f"torchanalyse: {flops_2:.0f} FLOPs") 65 | print("=" * 80) 66 | 67 | # ========= 68 | print("*" * 40 + " thop " + "*" * 40) 69 | macs_1, params = profile(model, inputs=(x, )) 70 | print(f"thop: {macs_1:.0f} MACs") 71 | print("=" * 80) 72 | 73 | # ========= 74 | print("*" * 40 + " ptflops " + "*" * 40) 75 | macs_2, params = get_model_complexity_info(model, tuple(x.shape), as_strings=False, 76 | print_per_layer_stat=True, verbose=True) 77 | print(f"ptflops: {macs_2:.0f} MACs") 78 | print("=" * 80) 79 | -------------------------------------------------------------------------------- /compare2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from torch_flops import TorchFLOPsByFX 5 | from torch.utils.flop_counter import FlopCounterMode 6 | import argparse 7 | 8 | ''' 9 | Comparion with torch.utils.flop_counter 10 | ''' 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--linear_no_bias", action='store_true') 14 | parser.add_argument("--add_one", action='store_true') 15 | inp_args = parser.parse_args() 16 | 17 | 18 | class SimpleModel(nn.Module): 19 | def __init__(self, args) -> None: 20 | super().__init__() 21 | self.layer = nn.Linear(5, 4, bias=not args.linear_no_bias) 22 | self.__add_one = args.add_one 23 | 24 | def forward(self, x: Tensor): 25 | x = self.layer(x) 26 | if self.__add_one: 27 | x += 1. 28 | return x 29 | 30 | 31 | if __name__ == "__main__": 32 | model = SimpleModel(inp_args).cuda() 33 | model.requires_grad_(False) 34 | model.eval() 35 | x = torch.randn(1, 5).cuda() 36 | y = model(x) 37 | print("*" * 40 + " Model " + "*" * 40) 38 | print(model) 39 | print(y) 40 | print("=" * 80) 41 | 42 | # ========= 43 | print("*" * 40 + " torch_flops " + "*" * 40) 44 | flops_counter = TorchFLOPsByFX(model) 45 | # flops_counter.graph_model.graph.print_tabular() 46 | flops_counter.propagate(x) 47 | flops_counter.print_result_table() 48 | flops_1 = flops_counter.print_total_flops(show=False) 49 | print(f"torch_flops: {flops_1} FLOPs") 50 | print("=" * 80) 51 | 52 | # ========= 53 | print("*" * 40 + " torch.utils.flop_counter " + "*" * 40) 54 | flops_counter = FlopCounterMode(model, depth=None, display=False) 55 | with flops_counter: 56 | model(x) 57 | flops_2 = flops_counter.get_total_flops() 58 | print(f"torch.utils.flop_counter: {flops_2} FLOPs") 59 | print("=" * 80) 60 | -------------------------------------------------------------------------------- /example1.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TIMM_FUSED_ATTN'] = "0" 3 | 4 | import torch 5 | from torch import Tensor 6 | import timm 7 | import warnings 8 | warnings.filterwarnings('ignore') 9 | from typing import Literal 10 | import argparse 11 | 12 | from torch_flops import TorchFLOPsByFX 13 | 14 | ''' 15 | Count the FLOPs of ViT-B16 and ResNet-50. 16 | ''' 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--model', choices=('vitb16', 'resnet50'), default='vitb16') 20 | args = parser.parse_args() 21 | 22 | if __name__ == "__main__": 23 | device = 'cuda:0' 24 | # Input 25 | x = torch.randn([1, 3, 224, 224]).to(device) 26 | model_arch: Literal['vitb16', 'resnet50'] = args.model 27 | 28 | if model_arch == 'vitb16': 29 | print("=" * 10, "vit_base16", "=" * 10) 30 | # Define the models 31 | vit = timm.create_model('vit_base_patch16_224').to(device) 32 | 33 | # NOTE: First run the model once for accurate time measurement in the following process. 34 | with torch.no_grad(): 35 | vit(x) 36 | 37 | with torch.no_grad(): 38 | # Build the graph of the model. You can specify the operations (listed in `MODULE_FLOPs_MAPPING`, `FUNCTION_FLOPs_MAPPING` and `METHOD_FLOPs_MAPPING` in 'flops_ops.py') to ignore. 39 | flops_counter = TorchFLOPsByFX(vit) 40 | # # Print the grath (not essential) 41 | # print('*' * 120) 42 | # flops_counter.graph_model.graph.print_tabular() 43 | # Feed the input tensor 44 | flops_counter.propagate(x) 45 | # # Print the flops of each node in the graph. Note that if there are unsupported operations, the "flops" of these ops will be marked as 'not recognized'. 46 | print('*' * 120) 47 | result_table = flops_counter.print_result_table() 48 | # # Print the total FLOPs 49 | total_flops = flops_counter.print_total_flops(show=True) 50 | total_time = flops_counter.print_total_time() 51 | max_memory = flops_counter.print_max_memory() 52 | elif model_arch == 'resnet50': 53 | print("=" * 10, "resnet50", "=" * 10) 54 | resnet = timm.create_model('resnet50').to(device) 55 | 56 | # NOTE: First run the model once for accurate time measurement in the following process. 57 | with torch.no_grad(): 58 | resnet(x) 59 | 60 | with torch.no_grad(): 61 | flops_counter = TorchFLOPsByFX(resnet) 62 | flops_counter.propagate(x) 63 | result_table = flops_counter.print_result_table() 64 | total_flops = flops_counter.print_total_flops(show=True) 65 | total_time = flops_counter.print_total_time() 66 | max_memory = flops_counter.print_max_memory() 67 | -------------------------------------------------------------------------------- /example2.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TIMM_FUSED_ATTN'] = "0" 3 | import torch 4 | from torch import Tensor 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.jit import Final 8 | from timm.layers import use_fused_attn, Mlp, DropPath 9 | 10 | from torch_flops import TorchFLOPsByFX 11 | 12 | 13 | class Attention(nn.Module): 14 | ''' 15 | REF: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 16 | ''' 17 | fused_attn: Final[bool] 18 | 19 | def __init__( 20 | self, 21 | dim, 22 | num_heads=8, 23 | qkv_bias=False, 24 | qk_norm=False, 25 | attn_drop=0., 26 | proj_drop=0., 27 | norm_layer=nn.LayerNorm, 28 | ): 29 | super().__init__() 30 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 31 | self.num_heads = num_heads 32 | self.head_dim = dim // num_heads 33 | self.scale = self.head_dim ** -0.5 34 | self.fused_attn = use_fused_attn() 35 | 36 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 37 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 38 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 39 | self.attn_drop = nn.Dropout(attn_drop) 40 | self.proj = nn.Linear(dim, dim) 41 | self.proj_drop = nn.Dropout(proj_drop) 42 | 43 | @torch.no_grad() 44 | def forward(self, x): 45 | B, N, C = x.shape 46 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 47 | q, k, v = qkv.unbind(0) 48 | q, k = self.q_norm(q), self.k_norm(k) 49 | 50 | if self.fused_attn: 51 | x = F.scaled_dot_product_attention( 52 | q, k, v, 53 | dropout_p=self.attn_drop.p, 54 | ) 55 | else: 56 | q = q * self.scale 57 | attn = q @ k.transpose(-2, -1) 58 | attn = attn.softmax(dim=-1) 59 | attn = self.attn_drop(attn) 60 | x = attn @ v 61 | 62 | x = x.transpose(1, 2).reshape(B, N, C) 63 | x = self.proj(x) 64 | x = self.proj_drop(x) 65 | return x 66 | 67 | 68 | class LayerScale(nn.Module): 69 | def __init__(self, dim, init_values=1e-5, inplace=False): 70 | super().__init__() 71 | self.inplace = inplace 72 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 73 | 74 | def forward(self, x): 75 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 76 | 77 | 78 | class Block(nn.Module): 79 | def __init__( 80 | self, 81 | dim, 82 | num_heads, 83 | mlp_ratio=4., 84 | qkv_bias=False, 85 | qk_norm=False, 86 | proj_drop=0., 87 | attn_drop=0., 88 | init_values=None, 89 | drop_path=0., 90 | act_layer=nn.GELU, 91 | norm_layer=nn.LayerNorm, 92 | mlp_layer=Mlp, 93 | ): 94 | super().__init__() 95 | self.norm1 = norm_layer(dim) 96 | self.attn = Attention( 97 | dim, 98 | num_heads=num_heads, 99 | qkv_bias=qkv_bias, 100 | qk_norm=qk_norm, 101 | attn_drop=attn_drop, 102 | proj_drop=proj_drop, 103 | norm_layer=norm_layer, 104 | ) 105 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 106 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 107 | 108 | self.norm2 = norm_layer(dim) 109 | self.mlp = mlp_layer( 110 | in_features=dim, 111 | hidden_features=int(dim * mlp_ratio), 112 | act_layer=act_layer, 113 | drop=proj_drop, 114 | ) 115 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 116 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 117 | 118 | def forward(self, x, y): 119 | x = x + y 120 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 121 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 122 | return x 123 | 124 | 125 | if __name__ == "__main__": 126 | C = 768 127 | device = 'cuda:0' 128 | 129 | # Define the model: an attention block (refer to "timm": https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py) 130 | block = Block(C, num_heads=2, qkv_bias=True) 131 | block.attn.fused_attn = False 132 | block.eval() 133 | model = block 134 | 135 | # Input 136 | # N: number of tokens 137 | N = 14**2 + 1 138 | B = 1 139 | x = torch.randn([B, N, C]).to(device) 140 | y = torch.randn([B, N, C]).to(device) 141 | model.to(device) 142 | 143 | # NOTE: First run the model once for accurate time measurement in the following process. 144 | with torch.no_grad(): 145 | model(x, y) 146 | 147 | # Output 148 | # Build the graph of the model. You can specify the operations (listed in `MODULE_FLOPs_MAPPING`, `FUNCTION_FLOPs_MAPPING` and `METHOD_FLOPs_MAPPING` in 'flops_ops.py') to ignore. 149 | flops_counter = TorchFLOPsByFX(model) 150 | # Print the grath (not essential) 151 | print('*' * 120) 152 | flops_counter.graph_model.graph.print_tabular() 153 | # Feed the input tensor 154 | with torch.no_grad(): 155 | flops_counter.propagate(x, y) 156 | # Print the flops of each node in the graph. Note that if there are unsupported operations, the "flops" of these ops will be marked as 'not recognized'. 157 | print('*' * 120) 158 | result_table = flops_counter.print_result_table() 159 | # Print the total FLOPs 160 | total_flops = flops_counter.print_total_flops() 161 | total_time = flops_counter.print_total_time() 162 | max_memory = flops_counter.print_max_memory() 163 | flops_counter.save_result_to_csv("./result.csv", 'w') 164 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | tabulate 3 | timm -------------------------------------------------------------------------------- /screen_shot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zugexiaodui/torch_flops/9844a27224c716941c8382bcd53fdb66f83969b4/screen_shot.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | pkg_version = 'UNKNOWN' 4 | with open("./torch_flops/__init__.py") as f: 5 | for line in f.readlines(): 6 | if line.startswith('__version__'): 7 | pkg_version = line.strip('\n').split('\'')[-2] 8 | readme_path = "README.md" 9 | setup( 10 | name="torch_flops", 11 | version=pkg_version, 12 | author="Yue Lu", 13 | author_email="luyue163@126.com", 14 | description="A library for calculating the FLOPs in the forward() process based on torch.fx", 15 | long_description=open(readme_path, encoding='utf-8').read(), 16 | long_description_content_type='text/markdown', 17 | url="https://github.com/zugexiaodui/torch_flops", 18 | data_files=[readme_path], 19 | requires=["python(>=3.10)", "torch(>=2.0)", "tabulate"], 20 | # install_requires=["torch>=1.8", "tabulate"], 21 | # python_requires=">=3.10", 22 | license=open("./LICENCE", encoding='utf-8').read() 23 | ) 24 | -------------------------------------------------------------------------------- /torch_flops/__init__.py: -------------------------------------------------------------------------------- 1 | from .flops_engine import TorchFLOPsByFX 2 | __version__ = '0.3.6' 3 | -------------------------------------------------------------------------------- /torch_flops/flops_engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.fx 4 | from torch.fx import symbolic_trace 5 | from torch.fx.node import Argument, Node, Target, map_aggregate 6 | from torch.fx._compatibility import compatibility 7 | from torch.fx.graph_module import GraphModule 8 | import traceback 9 | from tabulate import tabulate 10 | from typing import Any, Tuple, NamedTuple, Optional, Dict, Sequence, Literal 11 | from copy import deepcopy 12 | import time 13 | import csv 14 | 15 | from torch_flops.flops_ops import MODULE_FLOPs_MAPPING, METHOD_FLOPs_MAPPING, FUNCTION_FLOPs_MAPPING 16 | 17 | ''' 18 | REF: https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/shape_prop.py 19 | ''' 20 | 21 | __all__ = ['TorchFLOPsByFX'] 22 | 23 | 24 | @compatibility(is_backward_compatible=True) 25 | class TensorMetadata(NamedTuple): 26 | # TensorMetadata is a structure containing pertinent information 27 | # about a tensor within a PyTorch program. 28 | 29 | # General Tensor metadata 30 | shape: torch.Size 31 | dtype: torch.dtype 32 | requires_grad: bool 33 | stride: Tuple[int, ...] 34 | memory_format: Optional[torch.memory_format] 35 | is_quantized: bool 36 | qparams: Dict[str, Any] 37 | 38 | 39 | def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: 40 | """ 41 | Extract a TensorMetadata NamedTuple describing `result`. 42 | """ 43 | shape = result.shape 44 | dtype = result.dtype 45 | requires_grad = result.requires_grad 46 | stride = result.stride() 47 | 48 | memory_formats = { 49 | torch.contiguous_format, 50 | torch.channels_last, 51 | torch.channels_last_3d, 52 | } 53 | 54 | memory_format = None 55 | 56 | for query_format in memory_formats: 57 | if result.is_contiguous(memory_format=query_format): 58 | memory_format = query_format 59 | break 60 | 61 | is_quantized = result.is_quantized 62 | qparams: Dict[str, Any] = {} 63 | if is_quantized: 64 | qscheme = result.qscheme() 65 | qparams["qscheme"] = qscheme 66 | if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: 67 | qparams["scale"] = result.q_scale() # type: ignore[assignment] 68 | qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] 69 | elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: 70 | # In this branch, scale and zero_point are expected to be tensors, 71 | # we store the values as immutable_list in TensorMetadata for 72 | # easier serialization downstream 73 | qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] 74 | qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] 75 | qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] 76 | 77 | return TensorMetadata( 78 | shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) 79 | 80 | 81 | @compatibility(is_backward_compatible=True) 82 | class ShapeProp(torch.fx.Interpreter): 83 | """ 84 | Execute an FX graph Node-by-Node and 85 | record the shape and type of the result 86 | into the corresponding node. 87 | 88 | Example: 89 | In this example, we record the shape 90 | and data type of a module given 91 | an example input ``torch.randn(50, D_in)``. 92 | We print the name, shape and dtype of each node. 93 | 94 | class TwoLayerNet(torch.nn.Module): 95 | def __init__(self, D_in, H, D_out): 96 | super().__init__() 97 | self.linear1 = torch.nn.Linear(D_in, H) 98 | self.linear2 = torch.nn.Linear(H, D_out) 99 | def forward(self, x): 100 | h_relu = self.linear1(x).clamp(min=0) 101 | y_pred = self.linear2(h_relu) 102 | return y_pred 103 | N, D_in, H, D_out = 64, 1000, 100, 10 104 | x = torch.randn(N, D_in) 105 | y = torch.randn(N, D_out) 106 | model = TwoLayerNet(D_in, H, D_out) 107 | gm = torch.fx.symbolic_trace(model) 108 | sample_input = torch.randn(50, D_in) 109 | ShapeProp(gm).propagate(sample_input) 110 | 111 | for node in gm.graph.nodes: 112 | print(node.name, node.meta['tensor_meta'].dtype, 113 | node.meta['tensor_meta'].shape) 114 | 115 | The output of this code is: 116 | 117 | x torch.float32 torch.Size([50, 1000]) 118 | linear1 torch.float32 torch.Size([50, 100]) 119 | clamp_1 torch.float32 torch.Size([50, 100]) 120 | linear2 torch.float32 torch.Size([50, 10]) 121 | output torch.float32 torch.Size([50, 10]) 122 | 123 | Args: 124 | module (GraphModule): The module to be executed 125 | fake_mode (FakeTensorMode): A fake mode for copying the gm 126 | 127 | """ 128 | 129 | def __init__(self, gm: GraphModule, **kwargs): 130 | super().__init__(gm) 131 | mem_func_name: str = kwargs.get('mem_func_name', 'max_memory_allocated') 132 | assert mem_func_name in ['memory_allocated', 'max_memory_allocated'] 133 | ignore_ops = kwargs.get('ignore_ops', []) 134 | 135 | fake_mode = None 136 | if fake_mode is not None: 137 | from torch._dynamo.utils import deepcopy_to_fake_tensor 138 | # Note: 139 | # We need fake execution cause the inputs are fake, however, we cannot fakify the module 140 | # - because we need to write to the tensor_meta of the real module. So we fakify to 141 | # produce a result (L131 below), to extract tensor meta, and then keep going. 142 | # 143 | # If we were to fakify, we would write to the wrong node, and then downstream fusion 144 | # would be missing the tensor_meta. 145 | # 146 | # See torch/_inductor/overrides.py for where this is called upstream of fusion. 147 | self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode) 148 | self.fake_mode = fake_mode 149 | else: 150 | self.fake_module = None 151 | self.fake_mode = None 152 | 153 | self.real_module = self.module 154 | self.ignore_ops = ignore_ops 155 | self.mem_func_name = mem_func_name 156 | self.device = next(gm.parameters()).device 157 | 158 | @compatibility(is_backward_compatible=True) 159 | def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: 160 | """ 161 | Execute a ``call_module`` node and return the result. 162 | 163 | Args: 164 | target (Target): The call target for this node. See 165 | `Node `__ for 166 | details on semantics 167 | args (Tuple): Tuple of positional args for this invocation 168 | kwargs (Dict): Dict of keyword arguments for this invocation 169 | 170 | Return 171 | Any: The value returned by the module invocation 172 | """ 173 | # Retrieve executed args and kwargs values from the environment 174 | 175 | assert isinstance(target, str) 176 | submod = self.fetch_attr(target) 177 | 178 | if self.device.type == 'cuda': 179 | torch.cuda.synchronize(self.device) 180 | t_start = time.time() 181 | 182 | # Execute the method and return the result 183 | result = submod(*args, **kwargs) 184 | 185 | if self.device.type == 'cuda': 186 | torch.cuda.synchronize(self.device) 187 | t_end = time.time() 188 | exec_time = (t_end - t_start) * 1000 189 | 190 | # 计算出来result之后再计算FLOPs,保证计算过程能正确执行 191 | mod_name = submod.__class__.__name__ 192 | flops = None 193 | if mod_name in MODULE_FLOPs_MAPPING: 194 | if mod_name not in self.ignore_ops: 195 | flops = MODULE_FLOPs_MAPPING[mod_name](submod, result, *args, **kwargs) 196 | else: 197 | flops = 0 198 | 199 | return result, flops, exec_time 200 | 201 | @compatibility(is_backward_compatible=True) 202 | def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: 203 | """ 204 | Execute a ``call_function`` node and return the result. 205 | 206 | Args: 207 | target (Target): The call target for this node. See 208 | `Node `__ for 209 | details on semantics 210 | args (Tuple): Tuple of positional args for this invocation 211 | kwargs (Dict): Dict of keyword arguments for this invocation 212 | 213 | Return 214 | Any: The value returned by the function invocation 215 | """ 216 | assert not isinstance(target, str) 217 | 218 | if self.device.type == 'cuda': 219 | torch.cuda.synchronize(self.device) 220 | t_start = time.time() 221 | 222 | # Execute the function and return the result 223 | result = target(*args, **kwargs) 224 | 225 | if self.device.type == 'cuda': 226 | torch.cuda.synchronize(self.device) 227 | t_end = time.time() 228 | exec_time = (t_end - t_start) * 1000 229 | 230 | # 计算出来result之后再计算FLOPs,保证计算过程能正确执行 231 | func_name = target.__name__ 232 | flops = None 233 | if func_name in FUNCTION_FLOPs_MAPPING: 234 | if func_name not in self.ignore_ops: 235 | flops = FUNCTION_FLOPs_MAPPING[func_name](result, *args, **kwargs) 236 | else: 237 | flops = 0 238 | 239 | return result, flops, exec_time 240 | 241 | @compatibility(is_backward_compatible=True) 242 | def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: 243 | """ 244 | Execute a ``call_method`` node and return the result. 245 | 246 | Args: 247 | target (Target): The call target for this node. See 248 | `Node `__ for 249 | details on semantics 250 | args (Tuple): Tuple of positional args for this invocation 251 | kwargs (Dict): Dict of keyword arguments for this invocation 252 | 253 | Return 254 | Any: The value returned by the method invocation 255 | """ 256 | # args[0] is the `self` object for this method call 257 | self_obj, *args_tail = args 258 | 259 | assert isinstance(target, str) 260 | 261 | if self.device.type == 'cuda': 262 | torch.cuda.synchronize(self.device) 263 | t_start = time.time() 264 | 265 | # Execute the method and return the result 266 | result = getattr(self_obj, target)(*args_tail, **kwargs) 267 | 268 | if self.device.type == 'cuda': 269 | torch.cuda.synchronize(self.device) 270 | t_end = time.time() 271 | exec_time = (t_end - t_start) * 1000 272 | 273 | # 计算出来result之后再计算FLOPs,保证计算过程能正确执行 274 | method_name = target 275 | flops = None 276 | if method_name in METHOD_FLOPs_MAPPING: 277 | if method_name not in self.ignore_ops: 278 | flops = METHOD_FLOPs_MAPPING[method_name](self_obj, result, *args_tail, **kwargs) 279 | else: 280 | flops = 0 281 | return result, flops, exec_time 282 | 283 | def run_node(self, n: Node) -> Any: 284 | try: 285 | if self.fake_module is not None: 286 | # Hacky swap. Alternatively, we could do this with overriding 287 | # call_module and get_attr. 288 | self.module = self.fake_module 289 | try: 290 | if self.fake_mode is not None: 291 | raise ValueError("'fake_mode' must be None.") 292 | else: 293 | with self._set_current_node(n): 294 | args, kwargs = self.fetch_args_kwargs_from_env(n) 295 | assert isinstance(args, tuple) 296 | assert isinstance(kwargs, dict) 297 | 298 | mem_func = getattr(torch.cuda, self.mem_func_name) 299 | if self.mem_func_name == 'max_memory_allocated': 300 | torch.cuda.reset_peak_memory_stats(self.device) 301 | m_start = mem_func(self.device) 302 | 303 | if n.op in ('call_module', 'call_function', 'call_method'): 304 | result, flops, exec_time = getattr(self, n.op)(n.target, args, kwargs) 305 | else: 306 | if self.device.type == 'cuda': 307 | torch.cuda.synchronize(self.device) 308 | t_start = time.time() 309 | 310 | result = getattr(self, n.op)(n.target, args, kwargs) 311 | 312 | if self.device.type == 'cuda': 313 | torch.cuda.synchronize(self.device) 314 | t_end = time.time() 315 | exec_time = (t_end - t_start) * 1000 316 | 317 | flops = 0 318 | 319 | m_end = mem_func(self.device) 320 | 321 | assert flops not in n.meta, n.meta.keys() 322 | 323 | n.meta['flops'] = flops 324 | n.meta['time'] = exec_time 325 | n.meta['mem_before'] = m_start 326 | n.meta['mem_after'] = m_end 327 | n.meta['mem_delta'] = m_end - m_start 328 | finally: 329 | self.module = self.real_module 330 | except Exception as e: 331 | traceback.print_exc() 332 | raise RuntimeError( 333 | f"ShapeProp error for: node={n.format_node()} with " 334 | f"meta={n.meta}" 335 | ) from e 336 | 337 | found_tensor = False 338 | 339 | def extract_tensor_meta(obj): 340 | if isinstance(obj, torch.Tensor): 341 | nonlocal found_tensor 342 | found_tensor = True 343 | return _extract_tensor_metadata(obj) 344 | else: 345 | return obj 346 | 347 | meta = map_aggregate(result, extract_tensor_meta) 348 | if found_tensor: 349 | n.meta['tensor_meta'] = meta 350 | 351 | n.meta['type'] = type(result) 352 | 353 | return result 354 | 355 | def propagate(self, *args): 356 | """ 357 | Run `module` via interpretation and return the result and 358 | record the shape and type of each node. 359 | 360 | Args: 361 | *args (Tensor): the sample input. 362 | 363 | Returns: 364 | Any: The value returned from executing the Module 365 | """ 366 | if self.fake_mode is not None: 367 | raise ValueError("'fake_mode' must be None.") 368 | fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args] 369 | else: 370 | fake_args = args 371 | return super().run(*fake_args) 372 | 373 | 374 | class TorchFLOPsByFX(): 375 | def __init__(self, model: nn.Module, mem_func_name: Literal['memory_allocated', 'max_memory_allocated'] = 'max_memory_allocated', ignore_ops: Sequence[str] = []): 376 | ''' 377 | model: the model. 378 | mem_func_name: which function to measure the GPU memory; choosed from 'memory_allocated' and 'max_memory_allocated'; default: 'max_memory_allocated'. 379 | ignore_ops: the operations to be ignored for counting FLOPs. 380 | ''' 381 | model.eval() 382 | try: 383 | self.graph_model: GraphModule = symbolic_trace(model) 384 | except torch.fx.proxy.TraceError as e: 385 | print("\033[33mNOTE: The model cannot be built as a graph model by 'symbolic_trace()'. Please remove the `assert`, `if` and `for` operations. " + 386 | "See 'https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing' for more instructions.\033[0m") 387 | raise e 388 | except TypeError as e: 389 | print("\033[33mNOTE: The model cannot be built as a graph model by 'symbolic_trace()'. Please replace the `tensor.shape[i]` that servers as the parameter of a function with a pre-defined deterministic value.\033[0m") 390 | raise e 391 | 392 | assert mem_func_name in ['memory_allocated', 'max_memory_allocated'] 393 | self.mem_func_name = mem_func_name 394 | if isinstance(ignore_ops, str): 395 | ignore_ops = [ignore_ops] 396 | self.ignore_ops = deepcopy(ignore_ops) 397 | 398 | self.result_table = [] 399 | self.result_header = ['node_name', 'node_op', 'op_target', 'which_module', 'flops', 'time(ms)', 'mem_before_op(B)', 'mem_after_op(B)', 'mem_delta(B)'] 400 | self.__missing_values = [''] * 4 + ['ERROR'] 401 | self.__flag_propagated = False 402 | 403 | def propagate(self, *args): 404 | ShapeProp(self.graph_model, mem_func_name=self.mem_func_name, ignore_ops=self.ignore_ops).propagate(*args) 405 | 406 | result_table = [] 407 | for node in self.graph_model.graph.nodes: 408 | node: Node 409 | 410 | _target_str = str(node.target) 411 | if (_pattern := ' at 0x') in _target_str: 412 | _target_str = f"{_target_str.split(_pattern)[0]}>" 413 | if (_pattern := ' of type object') in _target_str: 414 | _target_str = f"{_target_str.split(_pattern)[0]}>" 415 | 416 | _result_row = [node.name, node.op, _target_str] 417 | 418 | node_module_name = '' 419 | if (_var_name := 'nn_module_stack') in node.meta: 420 | node_module_name = next(reversed(node.meta[_var_name].values())).__name__ 421 | # node_module_name = ".".join([_v.__name__ for _v in node.meta[_var_name].values()]) 422 | _result_row.append(node_module_name) 423 | 424 | for _var_name in ('flops', 'time', 'mem_before', 'mem_after', 'mem_delta'): 425 | if _var_name in node.meta: 426 | _var_val = node.meta[_var_name] 427 | if _var_val is None: 428 | _result_row.append('not_recognized') 429 | elif isinstance(_var_val, (int, float)): 430 | if node_module_name in self.ignore_ops: 431 | _result_row.append('ignored') 432 | else: 433 | _result_row.append(_var_val) 434 | else: 435 | raise TypeError(type(_var_val)) 436 | else: 437 | raise KeyError(f"'{_var_name}' must be in node.meta") 438 | 439 | assert len(_result_row) == len(self.result_header) 440 | result_table.append(_result_row) 441 | 442 | self.result_table = result_table 443 | self.__flag_propagated = True 444 | 445 | def print_result_table(self, show: bool = True) -> list[list[str | int | float]]: 446 | ''' 447 | Print the full result table. 448 | return: the results in a 2D list (excluding the head of the table). 449 | ''' 450 | table_str = tabulate(self.result_table, self.result_header, tablefmt='rst', 451 | intfmt=[''] * 4 + [','] + [''] + [','] * 2 + ['+,'], 452 | floatfmt='.3f', 453 | missingval=self.__missing_values) 454 | if show: 455 | print(table_str) 456 | return self.result_table 457 | 458 | def print_total_flops(self, show: bool = True) -> int: 459 | if not self.__flag_propagated: 460 | raise RuntimeError(f"Use `propagate()` method first.") 461 | 462 | valid_flops_list = list(filter(lambda _f: isinstance(_f, int), list(zip(*self.result_table))[4])) 463 | total_flops = sum(valid_flops_list) 464 | num_empty_flops = len(self.result_table) - len(valid_flops_list) 465 | 466 | if show: 467 | print(f"total_flops = {total_flops:3,}", f"({num_empty_flops} operations are ignored or not recognized)" if num_empty_flops else "") 468 | 469 | """ 470 | total_flops = None 471 | try: 472 | total_flops = sum(filter(list(zip(*self.result_table))[-1])) 473 | except TypeError as e: 474 | print("\033[33mNOTE: There may be some operations not recognized. Please check them using `print_result_table()` and then add them to 'flops_ops.py'\033[0m") 475 | self.print_result_table() 476 | print(f"\033[33mNot Recognized: {set([_m for _m,_f in zip(*list(zip(*self.result_table))[-2:]) if _f is None])}\033[0m") 477 | print(f"\033[31m{traceback.format_exc()}\033[0m") 478 | exit(-1) 479 | """ 480 | return total_flops 481 | 482 | def print_total_time(self, show: bool = True) -> float: 483 | if not self.__flag_propagated: 484 | raise RuntimeError(f"Use `propagate()` method first.") 485 | 486 | valid_time_list = list(zip(*self.result_table))[5] 487 | total_time = sum(valid_time_list) 488 | 489 | if show: 490 | print(f"total_time = {total_time:.3f} ms") 491 | 492 | return total_time 493 | 494 | def print_max_memory(self, show=True) -> int: 495 | if not self.__flag_propagated: 496 | raise RuntimeError(f"Use `propagate()` method first.") 497 | 498 | valid_mem_list = list(zip(*self.result_table))[7] 499 | max_mem = max(valid_mem_list) 500 | 501 | if show: 502 | print(f"max_memory = {max_mem:3,} Bytes") 503 | 504 | return max_mem 505 | 506 | def save_result_to_csv(self, file_path:str, mode:str='a'): 507 | with open(file_path, mode) as f: 508 | writer = csv.writer(f) 509 | writer.writerow(self.result_header) 510 | writer.writerows(self.result_table) 511 | f.write('\n') -------------------------------------------------------------------------------- /torch_flops/flops_ops.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor, Size 2 | from torch.types import Number 3 | 4 | __all__ = ['MODULE_FLOPs_MAPPING', 'FUNCTION_FLOPs_MAPPING', 'METHOD_FLOPs_MAPPING'] 5 | 6 | 7 | def flops_zero() -> int: 8 | return 0 9 | 10 | 11 | def flops_elemwise(result_shape: Size) -> int: 12 | return result_shape.numel() 13 | 14 | 15 | def flops_matmul(tensor1_shape: Size, tensor2_shape: Size, result_shape: Size) -> int: 16 | # 可根据输入维度改为分情况处理,参考https://github.com/zhijian-liu/torchprofile/blob/6d80fe57bb8c6bc9f789da7925fac6547fa9502b/torchprofile/handlers.py#L35 17 | def get_reduce_dim_shape(_s: Size, is_first_mat: bool): 18 | return _s[0] if len(_s) == 1 else _s[-1 if is_first_mat else -2] 19 | 20 | reduce_dim_shape = get_reduce_dim_shape(tensor1_shape, True) 21 | assert reduce_dim_shape == get_reduce_dim_shape(tensor2_shape, False) 22 | return (2 * reduce_dim_shape - 1) * result_shape.numel() 23 | 24 | # For nn.modules.* 25 | def flops_convnd(module: nn.modules.conv._ConvNd, input_shape: Size, result_shape: Size) -> int: 26 | kernel_size = Size([__k]) if isinstance(__k := module.kernel_size, int) else Size(__k) 27 | window_flops_per_chan = 2 * kernel_size.numel() - 1 28 | effective_in_chan = module.in_channels // module.groups 29 | window_flops = effective_in_chan * window_flops_per_chan + (effective_in_chan - 1) 30 | conv_flops = result_shape.numel() * window_flops 31 | bias_flops = result_shape.numel() if module.bias is not None else 0 32 | return conv_flops + bias_flops 33 | # return (2 * kernel_size.numel() * module.in_channels // module.groups - int(module.bias is None)) * result_shape.numel() 34 | 35 | 36 | def flops_avgpoolnd(module: nn.modules.pooling._AvgPoolNd, input_shape: Size, result_shape: Size) -> int: 37 | kernel_size = Size([__k]) if isinstance(__k := module.kernel_size, int) else Size(__k) 38 | return kernel_size.numel() * result_shape.numel() 39 | 40 | 41 | def flops_adaptive_avgpoolnd(module: nn.modules.pooling._AdaptiveAvgPoolNd, input_shape: Size, result_shape: Size) -> int: 42 | kernel_size = Size( 43 | i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 44 | for i_size, o_size in zip(input_shape[2:], result_shape[2:]) 45 | ) 46 | return kernel_size.numel() * result_shape.numel() 47 | 48 | 49 | def flops_maxpoolnd(module: nn.modules.pooling._AvgPoolNd, input_shape: Size, result_shape: Size) -> int: 50 | kernel_size = Size([__k]) if isinstance(__k := module.kernel_size, int) else Size(__k) 51 | return (kernel_size.numel() - 1) * result_shape.numel() 52 | 53 | 54 | def flops_adaptive_maxpoolnd(module: nn.modules.pooling._AdaptiveMaxPoolNd, input_shape: Size, result_shape: Size) -> int: 55 | kernel_size = Size( 56 | i_size // o_size if (i_size % o_size) == 0 else i_size - o_size * (i_size // o_size) + 1 57 | for i_size, o_size in zip(input_shape[2:], result_shape[2:]) 58 | ) 59 | return (kernel_size.numel() - 1) * result_shape.numel() 60 | 61 | 62 | def flops_functional_convnd(bias: int, groups: int, kernel_size: Size, in_channels: int, result_shape: Size) -> int: 63 | total_flops = (2 * kernel_size.numel() * in_channels - int(bias is None) * groups) * result_shape.numel() 64 | return total_flops 65 | 66 | 67 | # For ModuleFLOPs 68 | def ModuleFLOPs_zero(module: nn.Linear, result: Tensor, *args, **kwargs) -> int: 69 | return flops_zero() 70 | 71 | 72 | def ModuleFLOPs_elemwise(module: nn.Module, result: Tensor, *args, **kwargs) -> int: 73 | assert len(args) == 1 74 | assert isinstance(args[0], Tensor) 75 | assert isinstance(result, Tensor) 76 | 77 | input_shape = args[0].shape # [..., d_in] 78 | result_shape = result.shape 79 | assert input_shape == result_shape 80 | 81 | total_flops = flops_elemwise(result_shape) 82 | return total_flops 83 | 84 | 85 | def ModuleFLOPs_LeakyReLU(module: nn.LeakyReLU, result: Tensor, *args, **kwargs) -> int: 86 | return result.numel() * 4 87 | 88 | 89 | def ModuleFLOPs_Linear(module: nn.Linear, result: Tensor, *args, **kwargs) -> int: 90 | assert len(args) == 1 91 | assert isinstance(args[0], Tensor) 92 | assert isinstance(result, Tensor) 93 | 94 | input_shape = args[0].shape # [..., d_in] 95 | weight_shape = module.weight.T.shape # [d_out, d_in].T -> [d_in, d_out] 96 | result_shape = result.shape 97 | 98 | assert input_shape[-1] == weight_shape[0], f"{input_shape}, {weight_shape}" 99 | matmul_shape = Size(list(input_shape[:-1]) + list(weight_shape[-1:])) 100 | assert matmul_shape == result_shape 101 | 102 | total_flops = flops_matmul(input_shape, weight_shape, result_shape) 103 | if module.bias is not None: 104 | total_flops += flops_elemwise(result_shape) 105 | 106 | return total_flops 107 | 108 | 109 | def ModuleFLOPs_ConvNd(module: nn.Conv1d | nn.Conv2d | nn.Conv3d, result: Tensor, *args, **kwargs) -> int: 110 | assert len(args) == 1 111 | assert isinstance(args[0], Tensor) 112 | assert isinstance(result, Tensor) 113 | 114 | input_shape = args[0].shape 115 | result_shape = result.shape 116 | 117 | total_flops = flops_convnd(module, input_shape, result_shape) 118 | return total_flops 119 | 120 | 121 | def ModuleFLOPs_AvgPoolNd(module: nn.AvgPool1d | nn.AvgPool2d | nn.AvgPool3d, result: Tensor, *args, **kwargs) -> int: 122 | assert len(args) == 1 123 | assert isinstance(args[0], Tensor) 124 | assert isinstance(result, Tensor) 125 | 126 | input_shape = args[0].shape 127 | result_shape = result.shape 128 | 129 | total_flops = flops_avgpoolnd(module, input_shape, result_shape) 130 | return total_flops 131 | 132 | 133 | def ModuleFLOPs_AdaptiveAvgPoolNd(module: nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d, result: Tensor, *args, **kwargs) -> int: 134 | assert len(args) == 1 135 | assert isinstance(args[0], Tensor) 136 | assert isinstance(result, Tensor) 137 | 138 | input_shape = args[0].shape 139 | result_shape = result.shape 140 | 141 | total_flops = flops_adaptive_avgpoolnd(module, input_shape, result_shape) 142 | return total_flops 143 | 144 | 145 | def ModuleFLOPs_MaxPoolNd(module: nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d, result: Tensor, *args, **kwargs) -> int: 146 | assert len(args) == 1 147 | assert isinstance(args[0], Tensor) 148 | assert isinstance(result, Tensor) 149 | 150 | input_shape = args[0].shape 151 | result_shape = result.shape 152 | 153 | total_flops = flops_maxpoolnd(module, input_shape, result_shape) 154 | return total_flops 155 | 156 | 157 | def ModuleFLOPs_AdaptiveMaxPoolNd(module: nn.AdaptiveMaxPool1d | nn.AdaptiveMaxPool2d | nn.AdaptiveMaxPool3d, result: Tensor, *args, **kwargs) -> int: 158 | assert len(args) == 1 159 | assert isinstance(args[0], Tensor) 160 | assert isinstance(result, Tensor) 161 | 162 | input_shape = args[0].shape 163 | result_shape = result.shape 164 | 165 | total_flops = flops_adaptive_maxpoolnd(module, input_shape, result_shape) 166 | return total_flops 167 | 168 | 169 | def ModuleFLOPs_Norm(module: nn.modules.batchnorm._NormBase | nn.LayerNorm | nn.GroupNorm, result: Tensor, *args, **kwargs) -> int: 170 | assert len(args) == 1 171 | assert isinstance(args[0], Tensor) 172 | assert isinstance(result, Tensor) 173 | assert not module.training, "Only support `eval` mode." 174 | 175 | input_shape = args[0].shape # [..., d_in] 176 | result_shape = result.shape 177 | assert input_shape == result_shape 178 | 179 | # (X-mean)/std 180 | total_flops = flops_elemwise(input_shape) * 2 181 | if (hasattr(module, 'affine') and module.affine) or (hasattr(module, 'elementwise_affine'), module.elementwise_affine): 182 | total_flops += flops_elemwise(input_shape) * 2 183 | 184 | return total_flops 185 | 186 | 187 | def ModuleFLOPs_GELU(module: nn.GELU, result: Tensor, *args, **kwargs) -> int: 188 | assert len(args) == 1 189 | assert isinstance(args[0], Tensor) 190 | assert isinstance(result, Tensor) 191 | 192 | input_shape = args[0].shape # [..., d_in] 193 | result_shape = result.shape 194 | assert input_shape == result_shape 195 | 196 | total_flops = flops_elemwise(result_shape) 197 | if module.approximate is None: 198 | raise NotImplementedError() 199 | 200 | return total_flops 201 | 202 | 203 | # For FunctionFLOPs 204 | def FunctionFLOPs_zero(result: Tensor, *args, **kwargs) -> int: 205 | return flops_zero() 206 | 207 | 208 | def FunctionFLOPs_elemwise(result: Tensor | Number, *args, **kwargs) -> int: 209 | assert len(args) == 2, len(args) 210 | 211 | total_flops = None 212 | if isinstance(result, Number): 213 | total_flops = 1 214 | elif isinstance(result, Tensor): 215 | total_flops = flops_elemwise(result.shape) 216 | else: 217 | raise TypeError(type(result)) 218 | 219 | return total_flops 220 | 221 | 222 | def FunctionFLOPs_matmul(result: Tensor, *args, **kwargs) -> int: 223 | assert len(args) == 2, len(args) 224 | tensor_A, tensor_B = args 225 | assert isinstance(tensor_A, Tensor) and isinstance(tensor_B, Tensor) 226 | 227 | total_flops = flops_matmul(tensor_A.shape, tensor_B.shape, result.shape) 228 | return total_flops 229 | 230 | 231 | def FunctionFLOPs_linear(result: Tensor, *args, **kwargs) -> int: 232 | if len(args) == 3: 233 | input, weight, bias = args 234 | elif len(args) == 2: 235 | input, weight = args 236 | bias = kwargs.get('bias') 237 | else: 238 | input = args[0] 239 | weight = kwargs.get('weight') 240 | bias = kwargs.get('bias') 241 | 242 | assert isinstance(input, Tensor) and isinstance(weight, Tensor) 243 | 244 | total_flops = flops_matmul(input.shape, weight.T.shape, result.shape) 245 | if bias is not None: 246 | total_flops += flops_elemwise(result.shape) 247 | return total_flops 248 | 249 | 250 | def FunctionFLOPs_convnd(result: Tensor, *args, **kwargs) -> int: 251 | 252 | input = args[0] 253 | if len(args) > 1: 254 | weight = args[1] 255 | else: 256 | weight = kwargs.get('weight') 257 | 258 | assert isinstance(input, Tensor) 259 | assert isinstance(weight, Tensor) 260 | 261 | kernel_size = weight.shape[2:] 262 | in_channels = weight.shape[1] 263 | bias = kwargs.get('bias') 264 | groups = kwargs.get('groups', None) 265 | if groups is None: 266 | groups = 1 267 | stride = kwargs.get('stride', None) 268 | if stride is None: 269 | stride = 1 270 | padding = kwargs.get('padding', None) 271 | if padding is None: 272 | padding = 0 273 | result_shape = result.shape 274 | 275 | return flops_functional_convnd(bias, groups, kernel_size, in_channels, result_shape) 276 | 277 | def FunctionFLOPs_leaky_relu(result: Tensor, *args, **kwargs) -> int: 278 | return result.numel() * 4 279 | 280 | def FunctionFLOPs_interpolate(result: Tensor, *args, **kwargs) -> int: 281 | input = args[0] 282 | if len(args) > 1: 283 | size = args[1] 284 | else: 285 | size = kwargs.get('size', None) 286 | 287 | if size is not None: 288 | if isinstance(size, tuple) or isinstance(size, list): 289 | prod = 1 290 | for s in size: 291 | prod *= s 292 | return int(prod) 293 | else: 294 | return int(size) 295 | 296 | if len(args) > 2: 297 | scale_factor = args[2] 298 | else: 299 | scale_factor = kwargs.get('scale_factor', None) 300 | 301 | flops = input.numel() 302 | if isinstance(scale_factor, tuple) and len(scale_factor) == len(input): 303 | prod = 1 304 | for s in scale_factor: 305 | prod *= s 306 | flops *= int(prod) 307 | else: 308 | flops *= scale_factor**len(input) 309 | 310 | return flops 311 | 312 | 313 | # For MethodFLOPs 314 | def MethodFLOPs_zero(self_obj: Tensor, result: Tensor, *args_tail, **kwargs) -> int: 315 | return flops_zero() 316 | 317 | 318 | def MethodFLOPs_elemwise(self_obj: Tensor, result: Tensor, *args_tail, **kwargs) -> int: 319 | return flops_elemwise(result.shape) 320 | 321 | 322 | def MethodFLOPs_sum(self_obj: Tensor, result: Tensor, *args_tail, **kwargs) -> int: 323 | this_shape = self_obj.squeeze().shape 324 | result_shape = result.squeeze().shape 325 | 326 | total_flops = None 327 | if len(result_shape) == 0: 328 | total_flops = self_obj.numel() - 1 329 | else: 330 | kept_shape = list(this_shape) 331 | for s in result_shape: 332 | kept_shape.remove(s) 333 | kept_shape = Size(kept_shape) 334 | total_flops = kept_shape.numel() * (result_shape.numel() - 1) 335 | 336 | return total_flops 337 | 338 | 339 | def MethodFLOPs_softmax(self_obj: Tensor, result: Tensor, *args_tail, **kwargs) -> int: 340 | this_shape = self_obj.shape 341 | result_shape = result.shape 342 | assert this_shape == result_shape 343 | 344 | exp_flops = flops_elemwise(this_shape) 345 | 346 | dim_reduce: int = args_tail[0] if args_tail else kwargs.get('dim') 347 | dims_kept = list(this_shape) 348 | dims_kept.pop(dim_reduce) 349 | dims_kept = Size(dims_kept) 350 | sum_flops = (this_shape[dim_reduce] - 1) * dims_kept.numel() 351 | 352 | div_flops = flops_elemwise(this_shape) 353 | 354 | total_flops = exp_flops + sum_flops + div_flops 355 | return total_flops 356 | 357 | 358 | 359 | MODULE_FLOPs_MAPPING = { 360 | 'Linear': ModuleFLOPs_Linear, 361 | 'Identity': ModuleFLOPs_zero, 362 | 'Conv1d': ModuleFLOPs_ConvNd, 363 | 'Conv2d': ModuleFLOPs_ConvNd, 364 | 'Conv3d': ModuleFLOPs_ConvNd, 365 | 'AvgPool1d': ModuleFLOPs_AvgPoolNd, 366 | 'AvgPool2d': ModuleFLOPs_AvgPoolNd, 367 | 'AvgPool3d': ModuleFLOPs_AvgPoolNd, 368 | 'AdaptiveAvgPool1d': ModuleFLOPs_AdaptiveAvgPoolNd, 369 | 'AdaptiveAvgPool2d': ModuleFLOPs_AdaptiveAvgPoolNd, 370 | 'AdaptiveAvgPool3d': ModuleFLOPs_AdaptiveAvgPoolNd, 371 | 'MaxPool1d': ModuleFLOPs_MaxPoolNd, 372 | 'MaxPool2d': ModuleFLOPs_MaxPoolNd, 373 | 'MaxPool3d': ModuleFLOPs_MaxPoolNd, 374 | 'AdaptiveMaxPool1d': ModuleFLOPs_AdaptiveMaxPoolNd, 375 | 'AdaptiveMaxPool2d': ModuleFLOPs_AdaptiveMaxPoolNd, 376 | 'AdaptiveMaxPool3d': ModuleFLOPs_AdaptiveMaxPoolNd, 377 | 'LayerNorm': ModuleFLOPs_Norm, 378 | 'BatchNorm1d': ModuleFLOPs_Norm, 379 | 'BatchNorm2d': ModuleFLOPs_Norm, 380 | 'BatchNorm3d': ModuleFLOPs_Norm, 381 | 'InstanceNorm1d': ModuleFLOPs_Norm, 382 | 'InstanceNorm2d': ModuleFLOPs_Norm, 383 | 'InstanceNorm3d': ModuleFLOPs_Norm, 384 | 'GroupNorm': ModuleFLOPs_Norm, 385 | 'Dropout': ModuleFLOPs_zero, 386 | 'GELU': ModuleFLOPs_GELU, 387 | 'ReLU': ModuleFLOPs_elemwise, 388 | 'Flatten': ModuleFLOPs_zero, 389 | 'LeakyReLU': ModuleFLOPs_LeakyReLU, 390 | 'type_as': ModuleFLOPs_zero 391 | } 392 | FUNCTION_FLOPs_MAPPING = { 393 | 'getattr': FunctionFLOPs_zero, 394 | 'getitem': FunctionFLOPs_zero, 395 | 'mul': FunctionFLOPs_elemwise, 396 | 'truediv': FunctionFLOPs_elemwise, 397 | 'sub': FunctionFLOPs_elemwise, 398 | 'matmul': FunctionFLOPs_matmul, 399 | 'add': FunctionFLOPs_elemwise, 400 | 'concat': FunctionFLOPs_zero, 401 | '_assert': FunctionFLOPs_zero, 402 | 'eq': FunctionFLOPs_elemwise, 403 | 'cat': FunctionFLOPs_zero, 404 | 'linear': FunctionFLOPs_linear, 405 | 'conv1d': FunctionFLOPs_convnd, 406 | 'conv2d': FunctionFLOPs_convnd, 407 | 'conv3d': FunctionFLOPs_convnd, 408 | 'leaky_relu': FunctionFLOPs_leaky_relu, 409 | 'pad': FunctionFLOPs_zero, 410 | 'floordiv': FunctionFLOPs_zero, 411 | 'flip': FunctionFLOPs_zero, 412 | 'interpolate': FunctionFLOPs_interpolate, 413 | } 414 | METHOD_FLOPs_MAPPING = { 415 | 'reshape': MethodFLOPs_zero, 416 | 'permute': MethodFLOPs_zero, 417 | 'unbind': MethodFLOPs_zero, 418 | 'transpose': MethodFLOPs_zero, 419 | 'repeat': MethodFLOPs_zero, 420 | 'unsqueeze': MethodFLOPs_zero, 421 | 'exp': MethodFLOPs_elemwise, 422 | 'sum': MethodFLOPs_sum, 423 | 'div': MethodFLOPs_elemwise, 424 | 'softmax': MethodFLOPs_softmax, 425 | 'expand': MethodFLOPs_zero, 426 | 'flatten': MethodFLOPs_zero, 427 | 'view': MethodFLOPs_zero, 428 | 'cuda': MethodFLOPs_zero, 429 | 'flip': MethodFLOPs_zero, 430 | 'type_as': MethodFLOPs_zero, 431 | 'size': MethodFLOPs_zero, 432 | 'clone': MethodFLOPs_zero, 433 | 'new_empty': MethodFLOPs_zero, 434 | 'normal_': MethodFLOPs_zero, 435 | 'pow': MethodFLOPs_zero, 436 | } 437 | --------------------------------------------------------------------------------