├── .gitignore
├── LICENSE
├── README.md
├── setup.py
└── torchscope
    ├── __init__.py
    ├── helper.py
    └── scope.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | 
8 | # C extensions
9 | *.so
10 | 
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # pyenv
78 | .python-version
79 | 
80 | # celery beat schedule file
81 | celerybeat-schedule
82 | 
83 | # SageMath parsed files
84 | *.sage.py
85 | 
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 | 
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 | 
99 | # Rope project settings
100 | .ropeproject
101 | 
102 | # mkdocs documentation
103 | /site
104 | 
105 | # mypy
106 | .mypy_cache/
107 | 
108 | # pycharm
109 | .idea/
110 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torchscope
2 | torchscope is a neat plugin for scoping models in PyTorch: given an input size, it reports each layer's output shape, parameter count, FLOPs, and multiply-adds (Madds). It is mainly based on [pytorch-summary](https://github.com/sksq96/pytorch-summary) and [torchstat](https://github.com/Swall0w/torchstat).
3 | 
4 | ## Installation
5 | 
6 | - Install via pip
7 | 
8 | ```
9 | $ pip install torchscope
10 | ```
11 | 
12 | - Install from source
13 | 
14 | ```
15 | $ pip install --upgrade git+https://github.com/Tramac/torchscope.git
16 | ```
17 | 
18 | ## Usage
19 | 
20 | ```python
21 | from torchvision.models import resnet18
22 | from torchscope import scope
23 | 
24 | model = resnet18()
25 | scope(model, input_size=(3, 224, 224))
26 | ```
27 | 
28 | ```
29 | ------------------------------------------------------------------------------------------------------
30 | Layer (type) Output Shape Params FLOPs Madds
31 | ======================================================================================================
32 | Conv2d-1 [1, 64, 112, 112] 9,408 118,013,952 235,225,088
33 | BatchNorm2d-2 [1, 64, 112, 112] 128 1,605,632 3,211,264
34 | ReLU-3 [1, 64, 112, 112] 0 802,816 802,816
35 | MaxPool2d-4 [1, 64, 56, 56] 0 802,816 1,605,632
36 | Conv2d-5 [1, 64, 56, 56] 36,864 115,605,504 231,010,304
37 | BatchNorm2d-6 [1, 64, 56, 56] 128 401,408 802,816
38 | ReLU-7 [1, 64, 56, 56] 0 200,704 200,704
39 | Conv2d-8 [1, 64, 56, 56] 36,864 115,605,504 231,010,304
40 | BatchNorm2d-9 [1, 64, 56, 56] 128 401,408 802,816
41 | ReLU-10 [1, 64, 56, 56] 0 200,704 200,704
42 | Conv2d-11 [1, 64, 56, 56] 36,864 115,605,504 231,010,304
43 | BatchNorm2d-12 [1, 64, 56, 56] 128 401,408 802,816
44 | ReLU-13 [1, 64, 56, 56] 0 200,704 200,704
45 | Conv2d-14 [1, 64, 56, 56] 36,864 115,605,504 231,010,304
46 | BatchNorm2d-15 [1, 64, 56, 56] 128 401,408 802,816
47 | ReLU-16 [1, 64, 56, 56] 0 200,704 200,704
48 | Conv2d-17 [1, 128, 28, 28] 73,728 57,802,752 115,505,152
49 | BatchNorm2d-18 [1, 128, 28, 28] 256 200,704 401,408
50 | ReLU-19 [1, 128, 28, 28] 0 100,352 100,352
51 | Conv2d-20 [1, 128, 28, 28] 147,456 115,605,504 231,110,656
52 | BatchNorm2d-21 [1, 128, 28, 28] 256 200,704 401,408
53 | Conv2d-22 [1, 128, 28, 28] 8,192 6,422,528 12,744,704
54 | BatchNorm2d-23 [1, 128, 28, 28] 256 200,704 401,408
55 | ReLU-24 [1, 128, 28, 28] 0 100,352 100,352
56 | Conv2d-25 [1, 128, 28, 28] 147,456 115,605,504 231,110,656
57 | BatchNorm2d-26 [1, 128, 28, 28] 256 200,704 401,408
58 | ReLU-27 [1, 128, 28, 28] 0 100,352 100,352
59 | Conv2d-28 [1, 128, 28, 28] 147,456 115,605,504 231,110,656
60 | BatchNorm2d-29 [1, 128, 28, 28] 256 200,704 401,408
61 | ReLU-30 [1, 128, 28, 28] 0 100,352 100,352
62 | Conv2d-31 [1, 256, 14, 14] 294,912 57,802,752 115,555,328
63 | BatchNorm2d-32 [1, 256, 14, 14] 512 100,352 200,704
64 | ReLU-33 [1, 256, 14, 14] 0 50,176 50,176
65 | Conv2d-34 [1, 256, 14, 14] 589,824 115,605,504 231,160,832
66 | BatchNorm2d-35 [1, 256, 14, 14] 512 100,352 200,704
67 | Conv2d-36 [1, 256, 14, 14] 32,768 6,422,528 12,794,880
68 | BatchNorm2d-37 [1, 256, 14, 14] 512 100,352 200,704
69 | ReLU-38 [1, 256, 14, 14] 0 50,176 50,176
70 | Conv2d-39 [1, 256, 14, 14] 589,824 115,605,504 231,160,832
71 | BatchNorm2d-40 [1, 256, 14, 14] 512 100,352 200,704
72 | ReLU-41 [1, 256, 14, 14] 0 50,176 50,176
73 | Conv2d-42 [1, 256, 14, 14] 589,824 115,605,504 231,160,832
74 | BatchNorm2d-43 [1, 256, 14, 14] 512 100,352 200,704
75 | ReLU-44 [1, 256, 14, 14] 0 50,176 50,176
76 | Conv2d-45 [1, 512, 7, 7] 1,179,648 57,802,752 115,580,416
77 | BatchNorm2d-46 [1, 512, 7, 7] 1,024 50,176 100,352
78 | ReLU-47 [1, 512, 7, 7] 0 25,088 25,088
79 | Conv2d-48 [1, 512, 7, 7] 2,359,296 115,605,504 231,185,920
80 | BatchNorm2d-49 [1, 512, 7, 7] 1,024 50,176 100,352
81 | Conv2d-50 [1, 512, 7, 7] 131,072 6,422,528 12,819,968
82 | BatchNorm2d-51 [1, 512, 7, 7] 1,024 50,176 100,352
83 | ReLU-52 [1, 512, 7, 7] 0 25,088 25,088
84 | Conv2d-53 [1, 512, 7, 7] 2,359,296 115,605,504 231,185,920
85 | BatchNorm2d-54 [1, 512, 7, 7] 1,024 50,176 100,352
86 | ReLU-55 [1, 512, 7, 7] 0 25,088 25,088
87 | Conv2d-56 [1, 512, 7, 7] 2,359,296 115,605,504 231,185,920
88 | BatchNorm2d-57 [1, 512, 7, 7] 1,024 50,176 100,352
89 | ReLU-58 [1, 512, 7, 7] 0 25,088 25,088
90 | AvgPool2d-59 [1, 512, 1, 1] 0 25,088 25,088
91 | Linear-60 [1, 1000] 513,000 512,000 1,023,000
92 | ======================================================================================================
93 | Total params: 11,689,512
94 | Trainable params: 11,689,512
95 | Non-trainable params: 0
96 | Total FLOPs: 1,822,176,768
97 | Total Madds: 3,639,535,640
98 | ----------------------------------------------------------------
99 | Input size (MB): 0.14
100 | Forward/backward pass size (MB): 14.26
101 | Params size (MB): 11.15
102 | Estimated Total Size (MB): 25.55
103 | FLOPs size (GB): 1.82
104 | Madds size (GB): 3.64
105 | ----------------------------------------------------------------
106 | ```
107 | 
108 | ## Note
109 | 
110 | This plugin currently supports only the following operations; any other layer type is counted as 0 FLOPs and 0 Madds:
111 | 
112 | - Conv2d
113 | - ConvTranspose2d (Madds only)
114 | - BatchNorm2d
115 | - MaxPool2d / AvgPool2d
116 | - ReLU (also ReLU6, PReLU, ELU, LeakyReLU)
117 | - Upsample (FLOPs only)
118 | - Linear
119 | - Softmax / Bilinear (Madds only)
120 | 
121 | ## Reference
122 | 
123 | - [pytorch-summary](https://github.com/sksq96/pytorch-summary)
124 | - [torchstat](https://github.com/Swall0w/torchstat)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup(
4 |     name="torchscope",
5 |     version="0.1.0",
6 |     description="Model scope in PyTorch",
7 |     url="https://github.com/Tramac/pytorchscope",
8 |     author="Tramac",
9 |     packages=["torchscope"],
10 | )
--------------------------------------------------------------------------------
/torchscope/__init__.py:
--------------------------------------------------------------------------------
1 | from .scope import scope
--------------------------------------------------------------------------------
/torchscope/helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | 
4 | __all__ = ["compute_flops", "compute_madd"]
5 | 
6 | 
7 | def compute_flops(module, inp, out):
8 |     if isinstance(module, nn.Conv2d):
9 |         return compute_Conv2d_flops(module, inp, out) // 2  # halve: scope.py feeds a dummy batch of 2
10 |     elif isinstance(module, nn.BatchNorm2d):
11 |         return compute_BatchNorm2d_flops(module, inp, out) // 2
12 |     elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
13 |         return compute_Pool2d_flops(module, inp, out) // 2
14 |     elif isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)):
15 |         return compute_ReLU_flops(module, inp, out) // 2
16 |     elif isinstance(module, nn.Upsample):
17 |         return compute_Upsample_flops(module, inp, out) // 2
18 |     elif isinstance(module, nn.Linear):
19 |         return compute_Linear_flops(module, inp, out) // 2
20 |     else:
21 |         return 0
22 | 
23 | 
24 | def compute_Conv2d_flops(module, inp, out):
25 |     # Can have multiple inputs, getting the first one
26 |     assert isinstance(module, nn.Conv2d)
27 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
28 | 
29 |     batch_size = inp.size()[0]
30 |     in_c = inp.size()[1]
31 |     k_h, k_w = module.kernel_size
32 |     out_c, out_h, out_w = out.size()[1:]
33 |     groups = module.groups
34 | 
35 |     filters_per_channel = out_c // groups
36 |     conv_per_position_flops = k_h * k_w * in_c * filters_per_channel
37 |     active_elements_count = batch_size * out_h * out_w
38 | 
39 |     total_conv_flops = conv_per_position_flops * active_elements_count
40 | 
41 |     bias_flops = 0
42 |     if module.bias is not None:
43 |         bias_flops = out_c * active_elements_count
44 | 
45 |     total_flops = total_conv_flops + bias_flops
46 |     return total_flops
47 | 
48 | 
49 | def compute_BatchNorm2d_flops(module, inp, out):
50 |     assert isinstance(module, nn.BatchNorm2d)
51 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
52 |     # normalization is one op per element; the affine transform doubles that
53 |     batch_flops = np.prod(inp.shape)
54 |     if module.affine:
55 |         batch_flops *= 2
56 |     return batch_flops
57 | 
58 | 
59 | def compute_ReLU_flops(module, inp, out):
60 |     assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU))
61 |     batch_size = inp.size()[0]
62 |     active_elements_count = batch_size
63 | 
64 |     for s in inp.size()[1:]:
65 |         active_elements_count *= s
66 | 
67 |     return active_elements_count
68 | 
69 | 
70 | def compute_Pool2d_flops(module, inp, out):
71 |     assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
72 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
73 |     return np.prod(inp.shape)
74 | 
75 | 
76 | def compute_Linear_flops(module, inp, out):
77 |     assert isinstance(module, nn.Linear)
78 |     assert len(inp.size()) == 2 and len(out.size()) == 2
79 |     batch_size = inp.size()[0]
80 |     return batch_size * inp.size()[1] * out.size()[1]
81 | 
82 | 
83 | def compute_Upsample_flops(module, inp, out):
84 |     assert isinstance(module, nn.Upsample)
85 |     # one op per output element; iterating out.size()[1:] keeps the channel dim
86 |     batch_size = inp.size()[0]
87 |     output_elements_count = batch_size
88 |     for s in out.size()[1:]:
89 |         output_elements_count *= s
90 | 
91 |     return output_elements_count
92 | 
93 | 
94 | def compute_madd(module, inp, out):
95 |     if isinstance(module, nn.Conv2d):
96 |         return compute_Conv2d_madd(module, inp, out)
97 |     elif isinstance(module, nn.ConvTranspose2d):
98 |         return compute_ConvTranspose2d_madd(module, inp, out)
99 |     elif isinstance(module, nn.BatchNorm2d):
100 |         return compute_BatchNorm2d_madd(module, inp, out)
101 |     elif isinstance(module, nn.MaxPool2d):
102 |         return compute_MaxPool2d_madd(module, inp, out)
103 |     elif isinstance(module, nn.AvgPool2d):
104 |         return compute_AvgPool2d_madd(module, inp, out)
105 |     elif isinstance(module, (nn.ReLU, nn.ReLU6)):
106 |         return compute_ReLU_madd(module, inp, out)
107 |     elif isinstance(module, nn.Softmax):
108 |         return compute_Softmax_madd(module, inp, out)
109 |     elif isinstance(module, nn.Linear):
110 |         return compute_Linear_madd(module, inp, out)
111 |     elif isinstance(module, nn.Bilinear):
112 |         return compute_Bilinear_madd(module, inp[0], inp[1], out)
113 |     else:
114 |         return 0
115 | 
116 | 
117 | def compute_Conv2d_madd(module, inp, out):
118 |     assert isinstance(module, nn.Conv2d)
119 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
120 | 
121 |     in_c = inp.size()[1]
122 |     k_h, k_w = module.kernel_size
123 |     out_c, out_h, out_w = out.size()[1:]
124 |     groups = module.groups
125 | 
126 |     # ops per output element
127 |     kernel_mul = k_h * k_w * (in_c // groups)
128 |     kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)
129 | 
130 |     kernel_mul_group = kernel_mul * out_h * out_w * (out_c // groups)
131 |     kernel_add_group = kernel_add * out_h * out_w * (out_c // groups)
132 | 
133 |     total_mul = kernel_mul_group * groups
134 |     total_add = kernel_add_group * groups
135 | 
136 |     return total_mul + total_add
137 | 
138 | 
139 | def compute_ConvTranspose2d_madd(module, inp, out):
140 |     assert isinstance(module, nn.ConvTranspose2d)
141 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
142 | 
143 |     in_c, in_h, in_w = inp.size()[1:]
144 |     k_h, k_w = module.kernel_size
145 |     out_c, out_h, out_w = out.size()[1:]
146 |     groups = module.groups
147 | 
148 |     kernel_mul = k_h * k_w * (in_c // groups)
149 |     kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)
150 | 
151 |     kernel_mul_group = kernel_mul * in_h * in_w * (out_c // groups)
152 |     kernel_add_group = kernel_add * in_h * in_w * (out_c // groups)
153 | 
154 |     total_mul = kernel_mul_group * groups
155 |     total_add = kernel_add_group * groups
156 | 
157 |     return total_mul + total_add
158 | 
159 | 
160 | def compute_BatchNorm2d_madd(module, inp, out):
161 |     assert isinstance(module, nn.BatchNorm2d)
162 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
163 | 
164 |     in_c, in_h, in_w = inp.size()[1:]
165 | 
166 |     # 1. sub mean
167 |     # 2. div standard deviation
168 |     # 3. mul alpha
169 |     # 4. add beta
170 |     return 4 * in_c * in_h * in_w
171 | 
172 | 
173 | def compute_MaxPool2d_madd(module, inp, out):
174 |     assert isinstance(module, nn.MaxPool2d)
175 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
176 | 
177 |     if isinstance(module.kernel_size, (tuple, list)):
178 |         k_h, k_w = module.kernel_size
179 |     else:
180 |         k_h, k_w = module.kernel_size, module.kernel_size
181 |     out_c, out_h, out_w = out.size()[1:]
182 | 
183 |     return (k_h * k_w - 1) * out_h * out_w * out_c
184 | 
185 | 
186 | def compute_AvgPool2d_madd(module, inp, out):
187 |     assert isinstance(module, nn.AvgPool2d)
188 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
189 | 
190 |     if isinstance(module.kernel_size, (tuple, list)):
191 |         k_h, k_w = module.kernel_size
192 |     else:
193 |         k_h, k_w = module.kernel_size, module.kernel_size
194 |     out_c, out_h, out_w = out.size()[1:]
195 | 
196 |     kernel_add = k_h * k_w - 1
197 |     kernel_avg = 1
198 | 
199 |     return (kernel_add + kernel_avg) * (out_h * out_w) * out_c
200 | 
201 | 
202 | def compute_ReLU_madd(module, inp, out):
203 |     assert isinstance(module, (nn.ReLU, nn.ReLU6))
204 | 
205 |     count = 1
206 |     for i in inp.size()[1:]:
207 |         count *= i
208 |     return count
209 | 
210 | 
211 | def compute_Softmax_madd(module, inp, out):
212 |     assert isinstance(module, nn.Softmax)
213 |     assert len(inp.size()) > 1
214 | 
215 |     count = 1
216 |     for s in inp.size()[1:]:
217 |         count *= s
218 |     exp = count
219 |     add = count - 1
220 |     div = count
221 |     return exp + add + div
222 | 
223 | 
224 | def compute_Linear_madd(module, inp, out):
225 |     assert isinstance(module, nn.Linear)
226 |     assert len(inp.size()) == 2 and len(out.size()) == 2
227 | 
228 |     num_in_features = inp.size()[1]
229 |     num_out_features = out.size()[1]
230 | 
231 |     mul = num_in_features
232 |     add = num_in_features - 1
233 |     return num_out_features * (mul + add)
234 | 
235 | 
236 | def compute_Bilinear_madd(module, inp1, inp2, out):
237 |     assert isinstance(module, nn.Bilinear)
238 |     assert len(inp1.size()) == 2 and len(inp2.size()) == 2 and len(out.size()) == 2
239 | 
240 |     num_in_features_1 = inp1.size()[1]
241 |     num_in_features_2 = inp2.size()[1]
242 |     num_out_features = out.size()[1]
243 | 
244 |     mul = num_in_features_1 * num_in_features_2 + num_in_features_2
245 |     add = num_in_features_1 * num_in_features_2 + num_in_features_2 - 1
246 |     return num_out_features * (mul + add)
--------------------------------------------------------------------------------
/torchscope/scope.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | 
5 | from collections import OrderedDict
6 | from .helper import compute_madd, compute_flops
7 | 
8 | __all__ = ["scope"]
9 | 
10 | 
11 | class ModelSummary(object):
12 |     def __init__(self, model, input_size, batch_size=-1, device='cpu'):
13 |         super(ModelSummary, self).__init__()
14 |         assert device.lower() in ['cuda', 'cpu']
15 |         self.model = model
16 |         self.batch_size = batch_size
17 | 
18 |         if device.lower() == "cuda" and torch.cuda.is_available():
19 |             dtype = torch.cuda.FloatTensor
20 |         else:
21 |             dtype = torch.FloatTensor
22 | 
23 |         # normalize input_size to a list
24 |         if isinstance(input_size, tuple):
25 |             input_size = list(input_size)
26 |         self.input_size = input_size
27 | 
28 |         # batch_size of 2 for batchnorm
29 |         x = torch.rand([2] + input_size).type(dtype)
30 | 
31 |         # create properties
32 |         self.summary = OrderedDict()
33 |         self.hooks = list()
34 | 
35 |         # register hook
36 |         model.apply(self.register_hook)
37 | 
38 |         # make a forward pass to trigger the hooks
39 |         model(x)
40 | 
41 |         # remove hooks
42 |         for h in self.hooks:
43 |             h.remove()
44 | 
45 |     def register_hook(self, module):
46 |         if len(list(module.children())) == 0:
47 |             self.hooks.append(module.register_forward_hook(self.hook))
48 | 
49 |     def hook(self, module, input, output):
50 |         class_name = str(module.__class__).split(".")[-1].split("'")[0]
51 |         module_idx = len(self.summary)
52 | 
53 |         m_key = "%s-%i" % (class_name, module_idx + 1)
54 |         self.summary[m_key] = OrderedDict()
55 |         self.summary[m_key]["input_shape"] = list(input[0].size())
56 |         if isinstance(output, (list, tuple)):
57 |             self.summary[m_key]["output_shape"] = [[-1] + list(o.size())[1:] for o in output]
58 |         else:
59 |             self.summary[m_key]["output_shape"] = list(output.size())
60 | 
61 |         # -------------------------
62 |         # compute module parameters
63 |         # -------------------------
64 |         params = 0
65 |         if hasattr(module, "weight") and hasattr(module.weight, "size"):
66 |             params += torch.prod(torch.LongTensor(list(module.weight.size())))
67 |             self.summary[m_key]["trainable"] = module.weight.requires_grad
68 |         if hasattr(module, "bias") and hasattr(module.bias, "size"):
69 |             params += torch.prod(torch.LongTensor(list(module.bias.size())))
70 |         self.summary[m_key]["nb_params"] = params
71 | 
72 |         # -------------------------
73 |         # compute module flops
74 |         # -------------------------
75 |         flops = compute_flops(module, input[0], output)
76 |         self.summary[m_key]["flops"] = flops
77 | 
78 |         # -------------------------
79 |         # compute module madds
80 |         # -------------------------
81 |         madds = compute_madd(module, input[0], output)
82 |         self.summary[m_key]["madds"] = madds
83 | 
84 |     def show(self):
85 |         print("------------------------------------------------------------------------------------------------------")
86 |         line = "{:>20} {:>25} {:>15} {:>15} {:>15}".format("Layer (type)", "Output Shape", "Params", "FLOPs", "Madds")
87 |         print(line)
88 |         print("======================================================================================================")
89 |         total_params, total_output, trainable_params, total_flops, total_madds = 0, 0, 0, 0, 0
90 |         for layer in self.summary:
91 |             line = "{:>20} {:>25} {:>15} {:>15} {:>15}".format(
92 |                 layer,
93 |                 str(self.summary[layer]["output_shape"]),
94 |                 "{0:,}".format(self.summary[layer]["nb_params"]),
95 |                 "{0:,}".format(self.summary[layer]["flops"]),
96 |                 "{0:,}".format(self.summary[layer]["madds"]),
97 |             )
98 |             total_params += self.summary[layer]["nb_params"]
99 |             total_output += np.prod(self.summary[layer]["output_shape"])
100 |             total_flops += self.summary[layer]["flops"]
101 |             total_madds += self.summary[layer]["madds"]
102 |             if "trainable" in self.summary[layer]:
103 |                 if self.summary[layer]["trainable"]:
104 |                     trainable_params += self.summary[layer]["nb_params"]
105 |             print(line)
106 | 
107 |         total_input_size = abs(np.prod(self.input_size) * self.batch_size / (1024 ** 2.))
108 |         total_output_size = abs(2. * total_output / (1024 ** 2.))  # x2 for gradients
109 |         total_params_size = abs(total_params.numpy() / (1024 ** 2.))
110 |         total_flops_size = abs(total_flops / 1e9)
111 |         total_madds_size = abs(total_madds / 1e9)
112 |         total_size = total_params_size + total_output_size + total_input_size
113 | 
114 |         print("======================================================================================================")
115 |         print("Total params: {0:,}".format(total_params))
116 |         print("Trainable params: {0:,}".format(trainable_params))
117 |         print("Non-trainable params: {0:,}".format(total_params - trainable_params))
118 |         print("Total FLOPs: {0:,}".format(total_flops))
119 |         print("Total Madds: {0:,}".format(total_madds))
120 |         print("----------------------------------------------------------------")
121 |         print("Input size (MB): %0.2f" % total_input_size)
122 |         print("Forward/backward pass size (MB): %0.2f" % total_output_size)
123 |         print("Params size (MB): %0.2f" % total_params_size)
124 |         print("Estimated Total Size (MB): %0.2f" % total_size)
125 |         print("FLOPs size (GB): %0.2f" % total_flops_size)
126 |         print("Madds size (GB): %0.2f" % total_madds_size)
127 |         print("----------------------------------------------------------------")
128 | 
129 | 
130 | def scope(model, input_size, batch_size=-1, device='cpu'):
131 |     summary = ModelSummary(model, input_size, batch_size, device)
132 |     summary.show()
--------------------------------------------------------------------------------
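
A quick sanity check of the counting conventions above: `scope.py` feeds the model a dummy batch of 2 (the batch-of-2 trick for BatchNorm), and `compute_flops` halves every raw count, so the table reports per-sample FLOPs; the Madd formulas never touch the batch dimension. The minimal sketch below re-derives the `Conv2d-1` row of the README table by hand. It assumes torchvision's ResNet-18 stem, `Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)`, which maps a 3x224x224 input to a 64x112x112 output.

```python
# Hand-check of the "Conv2d-1" row against the formulas in torchscope/helper.py.
in_c, out_c = 3, 64          # input / output channels of the stem conv
k_h, k_w = 7, 7              # kernel size
out_h, out_w = 112, 112      # output resolution (224 / stride 2)

# Params: one k_h x k_w filter per (in_c, out_c) pair, no bias.
assert k_h * k_w * in_c * out_c == 9_408

# FLOPs (compute_Conv2d_flops): one k_h*k_w*in_c dot product per output element;
# the dummy batch of 2 and the // 2 in compute_flops cancel out.
flops = (k_h * k_w * in_c * out_c) * (out_h * out_w)
assert flops == 118_013_952

# Madds (compute_Conv2d_madd): multiplies and additions counted separately,
# k_h*k_w*in_c multiplies and one fewer additions per output element (no bias).
kernel_mul = k_h * k_w * in_c            # 147
kernel_add = kernel_mul - 1              # 146
madds = (kernel_mul + kernel_add) * out_h * out_w * out_c
assert madds == 235_225_088
```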