├── .github ├── requirements │ ├── 1.3.1.txt │ ├── 1.4.0.txt │ ├── 1.5.0.txt │ ├── 1.6.0.txt │ ├── 1.7.1.txt │ └── 1.8.1.txt └── workflows │ ├── nemo.yml │ └── pythonpublish.yml ├── LICENSE ├── README.md ├── doc ├── Makefile ├── _static │ └── .gitinclude ├── _templates │ └── .gitinclude ├── conf.py ├── index.rst ├── nemo.quant.rst └── nemo.rst ├── nemo ├── __init__.py ├── evaluation.py ├── graph.py ├── precision.py ├── quant │ ├── __init__.py │ └── pact.py ├── relaxation.py ├── transf │ ├── __init__.py │ ├── bias.py │ ├── bn.py │ ├── common.py │ ├── deploy.py │ ├── equalize.py │ ├── export.py │ ├── pruning.py │ ├── sawb.py │ ├── statistics.py │ └── utils.py ├── transform.py └── utils.py ├── requirements.txt ├── setup.py ├── tests ├── mnist_test.py └── mobi_fq_qd_id │ ├── mobi_fq_qd.py │ ├── mobi_qd_id.py │ └── mobilenet.py └── var └── aloha.png /.github/requirements/1.3.1.txt: -------------------------------------------------------------------------------- 1 | torch==1.3.1 2 | torchvision>=0.4.1,<0.5.0 3 | numpy 4 | tqdm 5 | packaging 6 | scikit-learn 7 | -------------------------------------------------------------------------------- /.github/requirements/1.4.0.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | torchvision>=0.5.0,<0.6.0 3 | numpy 4 | tqdm 5 | packaging 6 | scikit-learn 7 | -------------------------------------------------------------------------------- /.github/requirements/1.5.0.txt: -------------------------------------------------------------------------------- 1 | torch==1.5.0 2 | torchvision>=0.6.0,<0.7.0 3 | numpy 4 | tqdm 5 | packaging 6 | scikit-learn 7 | -------------------------------------------------------------------------------- /.github/requirements/1.6.0.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | torchvision>=0.7.0,<0.8.0 3 | numpy 4 | tqdm 5 | packaging 6 | scikit-learn 7 | 
-------------------------------------------------------------------------------- /.github/requirements/1.7.1.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision>=0.8.0,<0.9.0 3 | numpy 4 | tqdm 5 | packaging 6 | scikit-learn 7 | -------------------------------------------------------------------------------- /.github/requirements/1.8.1.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | torchvision>=0.9.0,<0.10.0 3 | numpy 4 | tqdm 5 | packaging 6 | scikit-learn 7 | -------------------------------------------------------------------------------- /.github/workflows/nemo.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: NEMO 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [ "3.5" ] 19 | torch-version: [ "1.3.1", "1.4.0", "1.5.0", "1.6.0", "1.7.1", "1.8.1" ] 20 | 21 | steps: 22 | - uses: actions/checkout@v2.1.0 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2.2.1 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install -r .github/requirements/${{ matrix.torch-version }}.txt 31 | - name: Lint with flake8 32 | run: | 33 | pip install flake8 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide 37 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | - name: MNIST test 39 | run: | 40 | cd tests; wget https://raw.githubusercontent.com/FrancescoConti/nemo_examples_helper/master/mnist_cnn_fp.pt; PYTHONPATH=`pwd`/.. python mnist_test.py 41 | - name: MobileNet FQ-QD equivalence 42 | run: | 43 | cd tests/mobi_fq_qd_id 44 | wget https://raw.githubusercontent.com/FrancescoConti/nemo_examples_helper/master/mobilenet_1.0_128_best.pth 45 | wget https://raw.githubusercontent.com/FrancescoConti/nemo_examples_helper/master/input_fq.pth 46 | PYTHONPATH=`pwd`/../.. python mobi_fq_qd.py 47 | - name: MobileNet QD-ID equivalence 48 | run: | 49 | cd tests/mobi_fq_qd_id 50 | wget https://raw.githubusercontent.com/FrancescoConti/nemo_examples_helper/master/mobilenet_1.0_128_best.pth 51 | wget https://raw.githubusercontent.com/FrancescoConti/nemo_examples_helper/master/input_fq.pth 52 | wget https://raw.githubusercontent.com/FrancescoConti/nemo_examples_helper/master/mobi_qd_id_res.pth 53 | PYTHONPATH=`pwd`/../.. 
python mobi_qd_id.py 54 | 55 | -------------------------------------------------------------------------------- /.github/workflows/pythonpublish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NEMO (NEural Minimizer for pytOrch) 2 | **NEMO (NEural Minimizer for pytOrch)** is a small library for minimization of Deep Neural Networks developed in PyTorch, aimed at their deployment on ultra-low power, highly memory constrained platforms, in particular (but not exclusively) PULP-based microcontrollers. 
3 | NEMO features include: 4 | - deployment-related transformations such as BatchNorm folding, bias removal, weight equalization 5 | - collection of statistics on activations and weights 6 | - post-training quantization 7 | - quantization-aware fine-tuning, with partially automated precision relaxation 8 | - mixed-precision quantization 9 | - bit-accurate deployment model 10 | - export to ONNX 11 | 12 | NEMO operates on three different "levels" of quantization-aware DNN representations, all built upon `torch.nn.Module` and `torch.autograd.Function`: 13 | - fake-quantized *FQ*: replaces regular activations (e.g., ReLU) with quantization-aware ones (PACT) and dynamically quantized weights (with linear PACT-like quantization), maintaining full trainability (similar to the native PyTorch support, but not based on it). 14 | - quantized-deployable *QD*: replaces all function with deployment-equivalent versions, trading off trainability for a more accurate representation of numerical behavior on real hardware. 15 | - integer-deployable *ID*: replaces all activation and weight tensors used along the network with integer-based ones. It aims at bit-accurate representation of actual hardware behavior. 16 | All the quantized representations support mixed-precision weights (signed and asymmetric) and activations (unsigned). The current version of NEMO targets per-layer quantization; work on per-channel quantization is in progress. 17 | 18 | NEMO is organized as a Python library that can be applied with relatively small changes to an existing PyTorch based script or training framework. 19 | 20 | # Installation and requirements 21 | The NEMO library currently supports PyTorch >= 1.3.1 and runs on Python >= 3.5. 
22 | To install it from PyPI, just run: 23 | ``` 24 | pip install pytorch-nemo 25 | ``` 26 | You can also install a development (and editable) version of NEMO by directly downloading this repo: 27 | ``` 28 | git clone https://github.com/pulp-platform/nemo 29 | cd nemo 30 | pip install -e . 31 | ``` 32 | Then, you can import it in your script using 33 | ``` 34 | import nemo 35 | ``` 36 | 37 | # Example 38 | - MNIST post-training quantization: https://colab.research.google.com/drive/1AmcITfN2ELQe07WKQ9szaxq-WSu4hdQb 39 | 40 | # Documentation 41 | Full documentation for NEMO is under development (see `doc` folder). You can find a technical report covering the deployment-aware quantization methodology here: https://arxiv.org/abs/2004.05930 42 | 43 | # License 44 | NEMO is released under Apache 2.0, see the LICENSE file in the root of this repository for details. 45 | 46 | # Acknowledgements 47 | ![ALOHA Logo](/var/aloha.png) 48 | 49 | NEMO is an outcome of the European Commission [Horizon 2020 ALOHA Project](https://www.aloha-h2020.eu/), funded under the EU's Horizon 2020 Research and Innovation Programme, grant agreement no. 780788. 50 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /doc/_static/.gitinclude: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pulp-platform/nemo/5ea3338ae172f96e996bdf75a5dacdf795282929/doc/_static/.gitinclude -------------------------------------------------------------------------------- /doc/_templates/.gitinclude: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pulp-platform/nemo/5ea3338ae172f96e996bdf75a5dacdf795282929/doc/_templates/.gitinclude -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath('../..')) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'NeMO' 23 | copyright = '2019-2020, ETH Zurich (released under Apache 2.0)' 24 | author = 'Francesco Conti (fconti@iis.ee.ethz.ch)' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = 'v0.1' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | 'sphinx.ext.autodoc', 43 | 'sphinx.ext.todo', 44 | 'sphinx.ext.imgmath', 45 | 'sphinx.ext.viewcode', 46 | 'sphinx_rtd_theme' 47 | ] 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | templates_path = ['_templates'] 51 | 52 | # The suffix(es) of source filenames. 53 | # You can specify multiple suffix as a list of string: 54 | # 55 | # source_suffix = ['.rst', '.md'] 56 | source_suffix = '.rst' 57 | 58 | # The master toctree document. 59 | master_doc = 'index' 60 | 61 | # The language for content autogenerated by Sphinx. Refer to documentation 62 | # for a list of supported languages. 63 | # 64 | # This is also used if you do content translation via gettext catalogs. 65 | # Usually you set "language" from the command line for these cases. 66 | language = None 67 | 68 | # List of patterns, relative to source directory, that match files and 69 | # directories to ignore when looking for source files. 70 | # This pattern also affects html_static_path and html_extra_path. 
71 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 72 | 73 | # The name of the Pygments (syntax highlighting) style to use. 74 | pygments_style = None 75 | 76 | 77 | # -- Options for HTML output ------------------------------------------------- 78 | 79 | # 80 | # https://blog.deimos.fr/2014/10/02/sphinxdoc-and-readthedocs-theme-tricks-2/ 81 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 82 | 83 | if not on_rtd: # only import and set the theme if we're building docs locally 84 | import sphinx_rtd_theme 85 | html_theme = 'sphinx_rtd_theme' 86 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 87 | # Override default css to get a larger width for local build 88 | def setup(app): 89 | #app.add_javascript("custom.js") 90 | app.add_stylesheet('theme_overrides.css') 91 | else: 92 | # Override default css to get a larger width for ReadTheDoc build 93 | html_context = { 94 | 'css_files': [ 95 | 'https://media.readthedocs.org/css/sphinx_rtd_theme.css', 96 | 'https://media.readthedocs.org/css/readthedocs-doc-embed.css', 97 | '_static/theme_overrides.css', 98 | ], 99 | } 100 | 101 | # Add any paths that contain custom static files (such as style sheets) here, 102 | # relative to this directory. They are copied after the builtin static files, 103 | # so a file named "default.css" will overwrite the builtin "default.css". 104 | html_static_path = ['_static'] 105 | 106 | # Custom sidebar templates, must be a dictionary that maps document names 107 | # to template names. 108 | # 109 | # The default sidebars (for documents that don't match any pattern) are 110 | # defined by theme itself. Builtin themes are using these templates by 111 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 112 | # 'searchbox.html']``. 113 | # 114 | # html_sidebars = {} 115 | 116 | 117 | # -- Options for HTMLHelp output --------------------------------------------- 118 | 119 | # Output file base name for HTML help builder. 
120 | htmlhelp_basename = 'NeMOdoc' 121 | 122 | 123 | # -- Options for LaTeX output ------------------------------------------------ 124 | 125 | latex_elements = { 126 | # The paper size ('letterpaper' or 'a4paper'). 127 | # 128 | # 'papersize': 'letterpaper', 129 | 130 | # The font size ('10pt', '11pt' or '12pt'). 131 | # 132 | # 'pointsize': '10pt', 133 | 134 | # Additional stuff for the LaTeX preamble. 135 | # 136 | # 'preamble': '', 137 | 138 | # Latex figure (float) alignment 139 | # 140 | # 'figure_align': 'htbp', 141 | } 142 | 143 | # Grouping the document tree into LaTeX files. List of tuples 144 | # (source start file, target name, title, 145 | # author, documentclass [howto, manual, or own class]). 146 | latex_documents = [ 147 | (master_doc, 'NeMO.tex', 'NeMO Documentation', 148 | 'Francesco Conti (fconti@iis.ee.ethz.ch)', 'manual'), 149 | ] 150 | 151 | 152 | # -- Options for manual page output ------------------------------------------ 153 | 154 | # One entry per manual page. List of tuples 155 | # (source start file, name, description, authors, manual section). 156 | man_pages = [ 157 | (master_doc, 'nemo', 'NeMO Documentation', 158 | [author], 1) 159 | ] 160 | 161 | 162 | # -- Options for Texinfo output ---------------------------------------------- 163 | 164 | # Grouping the document tree into Texinfo files. List of tuples 165 | # (source start file, target name, title, author, 166 | # dir menu entry, description, category) 167 | texinfo_documents = [ 168 | (master_doc, 'NeMO', 'NeMO Documentation', 169 | author, 'NeMO', 'One line description of project.', 170 | 'Miscellaneous'), 171 | ] 172 | 173 | 174 | # -- Options for Epub output ------------------------------------------------- 175 | 176 | # Bibliographic Dublin Core info. 177 | epub_title = project 178 | 179 | # The unique identifier of the text. This can be a ISBN number 180 | # or the project homepage. 181 | # 182 | # epub_identifier = '' 183 | 184 | # A unique identification for the text. 
185 | # 186 | # epub_uid = '' 187 | 188 | # A list of files that should not be packed into the epub file. 189 | epub_exclude_files = ['search.html'] 190 | 191 | 192 | # -- Extension configuration ------------------------------------------------- 193 | 194 | # -- Options for todo extension ---------------------------------------------- 195 | 196 | # If true, `todo` and `todoList` produce output, else they produce nothing. 197 | todo_include_todos = True 198 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. NeMO documentation master file, created by 2 | sphinx-quickstart on Fri Oct 18 14:53:45 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to NeMO's documentation! 7 | ================================ 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | nemo 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /doc/nemo.quant.rst: -------------------------------------------------------------------------------- 1 | nemo.quant package 2 | ================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | nemo.quant.pact\_quant module 8 | ----------------------------- 9 | 10 | .. automodule:: nemo.quant.pact 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: nemo.quant 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /doc/nemo.rst: -------------------------------------------------------------------------------- 1 | nemo package 2 | ============ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
toctree:: 8 | 9 | nemo.quant 10 | 11 | Submodules 12 | ---------- 13 | 14 | nemo.evaluation module 15 | ---------------------- 16 | 17 | .. automodule:: nemo.evaluation 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | nemo.graph module 23 | ----------------- 24 | 25 | .. automodule:: nemo.graph 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | nemo.precision module 31 | --------------------- 32 | 33 | .. automodule:: nemo.precision 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | nemo.relaxation module 39 | ---------------------- 40 | 41 | .. automodule:: nemo.relaxation 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | nemo.transform module 47 | --------------------- 48 | 49 | .. automodule:: nemo.transform 50 | :members: 51 | :undoc-members: 52 | :show-inheritance: 53 | 54 | nemo.utils module 55 | ----------------- 56 | 57 | .. automodule:: nemo.utils 58 | :members: 59 | :undoc-members: 60 | :show-inheritance: 61 | 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. automodule:: nemo 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /nemo/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | # Francesco Conti 3 | # Alfio Di Mauro 4 | # 5 | # Copyright (C) 2018-2020 ETH Zurich 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | __all__ = ["quant", "transform", "relaxation", "utils", "evaluation", "graph"] 20 | from . import quant, transform, relaxation, utils, evaluation, graph 21 | 22 | -------------------------------------------------------------------------------- /nemo/evaluation.py: -------------------------------------------------------------------------------- 1 | # 2 | # evaluation.py 3 | # Francesco Conti 4 | # Alfio Di Mauro 5 | # 6 | # Copyright (C) 2018-2020 ETH Zurich 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
class EvaluationEngine():
    """Grid evaluation engine for mixed-precision search.

    Builds a 2D grid of weight (W) and activation (x) bit-widths from
    `precision_rule`, lets the caller sweep the grid point-by-point with
    `step()`, record accuracies with `report()`, and finally pick the next
    training configuration with `get_next_config()`.

    :param net: module under evaluation; must expose `change_precision`
        (and `W_precision`/`x_precision` for the private suffix helper).
    :param validate_fn: validation callable (stored only; never called here).
    :param validate_data: data handed to `validate_fn` (stored only).
    :param precision_rule: dict with keys `'0'` (initial `W_bits`/`x_bits`),
        `W_bit_scaler`, `x_bit_scaler`, `W_bit_stop_condition`,
        `x_bit_stop_condition` -- note the string key `'0'`.
    :param min_prec_dict: minimum-precision dictionary forwarded to
        `net.change_precision`.
    """

    def __init__(self, net, validate_fn=None, validate_data=None, precision_rule=None, min_prec_dict=None):
        super(EvaluationEngine, self).__init__()

        self.precision_rule = precision_rule
        # NOTE(review): these two flags are set only when precision_rule is
        # None and are never read again inside this class -- looks like state
        # kept for symmetry with RelaxationEngine; confirm before removing.
        if self.precision_rule is None:
            self.scale_activations = True
            self.scale_weights = True
        self.min_prec_dict = min_prec_dict
        self.loss_prev = 1e3
        self.net = net
        self.validate_fn = validate_fn
        self.validate_data = validate_data
        self.reset_grids()

    def reset_grids(self, W_start=None, W_stop=None, W_step=None, x_start=None, x_stop=None, x_step=None):
        """Rebuild the (W bits, x bits) evaluation grid and reset the cursor.

        Any argument left as None is taken from `precision_rule`. The bit
        scalers are presumably negative (precision is relaxed downwards) --
        TODO confirm against the producers of precision_rule.
        """
        first_W = self.precision_rule['0']['W_bits'] if W_start is None else W_start
        first_x = self.precision_rule['0']['x_bits'] if x_start is None else x_start
        step_W = float(self.precision_rule['W_bit_scaler']) if W_step is None else W_step
        step_x = float(self.precision_rule['x_bit_scaler']) if x_step is None else x_step
        # the stop condition is inclusive: extend the arange end by one step
        last_W = float(self.precision_rule['W_bit_stop_condition'])+step_W if W_stop is None else W_stop
        last_x = float(self.precision_rule['x_bit_stop_condition'])+step_x if x_stop is None else x_stop
        W = np.arange(first_W, last_W, step_W)
        x = np.arange(first_x, last_x, step_x)
        # degenerate ranges collapse to a single grid point
        if len(W) == 0:
            W = np.asarray([first_W,])
        if len(x) == 0:
            x = np.asarray([first_x,])
        logging.info("[Evaluation]\t Setting up new grid: W=%s x=%s" % (W, x))
        wgrid, xgrid = np.meshgrid(W, x)
        self.acc = np.zeros_like(wgrid, dtype='float32')
        # grids are stored flattened; self.idx is the linear cursor
        # (-1 means "before the first step() call")
        self.wgrid = wgrid.flatten()
        self.xgrid = xgrid.flatten()
        self.idx = -1

    def report(self, acc):
        """Record the accuracy of the configuration selected by the last step()."""
        sh = self.acc.shape
        self.acc = self.acc.flatten()
        self.acc[self.idx] = acc.item()
        self.acc = self.acc.reshape(sh)

    def __suffix(self):
        """Checkpoint-name suffix encoding the current W/x bit precision."""
        return "_%.1fx%.1fb" % (self.net.W_precision.get_bits(), self.net.x_precision.get_bits())

    def step(self, checkpoint_name='checkpoint', verbose=False):
        """Advance to the next grid point and apply its precision to the net.

        Returns False when the grid is exhausted, True after applying the
        next configuration, and None (implicit) when there is no
        precision_rule -- NOTE(review): mixed return types kept as-is.
        """
        if self.precision_rule is None:
            return
        # NOTE(review): curr_regime is computed but never used below; it
        # mirrors the lookup structure used by RelaxationEngine.step.
        try:
            curr_regime = self.precision_rule[0]
        except KeyError:
            try:
                curr_regime = self.precision_rule[str(0)]
            except KeyError:
                curr_regime = None

        if self.idx == self.wgrid.shape[0] - 1:
            return False

        self.idx += 1
        # apply weight precision first, then activation precision
        self.net.change_precision(bits=self.wgrid[self.idx], scale_activations=False, scale_weights=True, reset_alpha=False, verbose=verbose, min_prec_dict=self.min_prec_dict)
        self.net.change_precision(bits=self.xgrid[self.idx], scale_activations=True, scale_weights=False, reset_alpha=False, verbose=verbose, min_prec_dict=self.min_prec_dict)
        return True

    # the rationale here is to define frontiers using a certain threshold, e.g. 90% of the top accuracy
    # reached by pure evaluation. configurations within the frontier are considered 'good enough'
    # and not worth being considered for retraining (only minor fine-tuning is necessary).
    def get_next_config(self, upper_threshold=0.9, strategy='min_precision', verbose=False, Wbits_curr=None, xbits_curr=None, timeout=25):
        """Pick the next (W bits, x bits) configuration from the recorded grid.

        Accuracies are binned into bottom/middle/top frontiers relative to the
        grid maximum; the middle bin holds the candidates. `strategy` is
        'max_accuracy' or 'min_precision' (any other value leaves `idxs`
        unbound and raises NameError -- kept as-is). `Wbits_curr` and
        `xbits_curr` are currently unused.
        """
        def create_bins(upper_threshold):
            # bin edges [0, (1-t)*max, t*max]: digitize yields 1/2/3 for
            # bottom/middle/top
            bins = np.asarray([0, 1.0-upper_threshold, upper_threshold]) * self.acc.max()
            acc_idx = np.digitize(self.acc, bins)
            return acc_idx
        acc_idx = create_bins(upper_threshold)
        for i in range(timeout):
            if len(acc_idx[acc_idx == 2]) == 0:
                # exponentially increase threshold
                # NOTE(review): `x /= x` sets the threshold to exactly 1.0 on
                # the first retry, so the timeout loop converges in a single
                # iteration; possibly a typo for e.g. `upper_threshold **= 0.5`
                # -- confirm the intended relaxation schedule.
                upper_threshold /= upper_threshold
                if verbose:
                    logging.info("[Evaluation]\t No middle-bin element, changing threshold to %.3e." % upper_threshold)
                acc_idx = create_bins(upper_threshold)
        if len(acc_idx[acc_idx == 2]) == 0:
            # still nothing in the middle bin: fall back to the grid origin
            return self.wgrid.reshape(self.acc.shape)[0,0], self.xgrid.reshape(self.acc.shape)[0,0]
        if verbose:
            logging.info("[Evaluation]\t Top-bin: %s" % (np.dstack((self.wgrid.reshape(self.acc.shape)[acc_idx == 3], self.xgrid.reshape(self.acc.shape)[acc_idx == 3], self.acc[acc_idx == 3]))))
            logging.info("[Evaluation]\t Middle-bin: %s" % (np.dstack((self.wgrid.reshape(self.acc.shape)[acc_idx == 2], self.xgrid.reshape(self.acc.shape)[acc_idx == 2], self.acc[acc_idx == 2]))))
            logging.info("[Evaluation]\t Bottom-bin: %s" % (np.dstack((self.wgrid.reshape(self.acc.shape)[acc_idx == 1], self.xgrid.reshape(self.acc.shape)[acc_idx == 1], self.acc[acc_idx == 1]))))
        if strategy == 'max_accuracy':
            if verbose:
                logging.info("[Evaluation]\t Select the most accurate one from the middle-bin.")
            # NOTE(review): argmax is taken over the *masked* (1D) view but
            # unraveled against the full 2D shape -- verify the indices line up.
            idxs = np.unravel_index(self.acc[acc_idx == 2].argmax(), self.acc.shape)
        elif strategy == 'min_precision':
            if verbose:
                logging.info("[Evaluation]\t Select the lowest precision one from the middle-bin.")
            # among middle-bin points with the fewest activation bits, pick the
            # one with the fewest weight bits
            xmin = self.xgrid.reshape(self.acc.shape)[acc_idx == 2].min()
            mask = np.logical_and(acc_idx == 2, self.xgrid.reshape(self.acc.shape)==xmin)
            wtmpgrid = np.copy(self.wgrid.reshape(self.acc.shape))
            # exclude masked-out points from the argmin with a large sentinel
            wtmpgrid[np.logical_not(mask)] = 1e6
            idxs = np.unravel_index(wtmpgrid.argmin(), self.acc.shape)
        Widx = idxs[1]
        xidx = idxs[0]
        if verbose:
            # NOTE(review): wgrid/xgrid are flattened, so wgrid[Widx] and
            # xgrid[xidx] here index the flat arrays -- the precisions printed
            # may not match the returned pair.
            logging.info("[Evaluation]\t Returning idxs: %d,%d prec %d,%d" % (Widx, xidx, self.wgrid[Widx], self.xgrid[xidx]))
        return self.wgrid.reshape(self.acc.shape)[0,Widx], self.xgrid.reshape(self.acc.shape)[xidx,0]
# Copyright (C) 2018-2020 ETH Zurich 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import torch 20 | import torch.onnx.utils 21 | from nemo.precision import Precision 22 | from nemo.quant.pact import * 23 | from torch.nn.modules.utils import _single,_pair 24 | from collections import OrderedDict 25 | import types 26 | import logging 27 | import numpy as np 28 | import copy 29 | import math 30 | import torchvision.models 31 | import re 32 | from packaging import version 33 | 34 | def _hier_flat_dict_build(module, name): 35 | for n,m in module.named_children(): 36 | if n == name: 37 | return m 38 | elif n == name.split('.')[0]: 39 | return _hier_flat_dict_build(m, '.'.join(name.split('.')[1:])) 40 | return module 41 | 42 | def _hier_get_name_from_child(module, child, name): 43 | for n,m in module.named_children(): 44 | name_plus_n = name + '.' + n if name != '' else n 45 | if m == child: 46 | return name_plus_n 47 | elif len(list(m.named_children())) > 0: 48 | name_child = _hier_get_name_from_child(m, child, name_plus_n) 49 | if name_child is not '': 50 | return name_child 51 | return '' 52 | 53 | # necessary for PyTorch >= 1.4! 
class scope_name_workaround(object):
    """Context manager restoring meaningful scope names during ONNX tracing.

    Necessary for PyTorch >= 1.4, where tracing lost hierarchical module
    scope names; see
    https://github.com/pytorch/pytorch/issues/33463#issuecomment-606399944.
    On entry it monkey-patches `torch.nn.Module._slow_forward` so that each
    traced call pushes the module's fully-qualified name (resolved against
    the `mother` root module) as the tracing scope; on exit the original
    `_slow_forward` is restored.
    """

    def __init__(self, mother):
        # `backup` holds the original _slow_forward while the patch is active
        self.backup = None
        self.mother = mother

    def __enter__(self):
        def _tracing_name(self_, tracing_state):
            # name of self_ as a direct child of the module currently on top
            # of the tracing stack (or None when not resolvable)
            if not tracing_state._traced_module_stack:
                return None
            module = tracing_state._traced_module_stack[-1]
            for name, child in module.named_children():
                if child is self_:
                    return name
            return None

        def _slow_forward(self_, *input, **kwargs):
            # replacement for torch.nn.Module._slow_forward that pushes a
            # hierarchical scope name before dispatching to forward()
            tracing_state = torch._C._get_tracing_state()
            if not tracing_state or isinstance(self_.forward, torch._C.ScriptMethod):
                return self_.forward(*input, **kwargs)
            if not hasattr(tracing_state, '_traced_module_stack'):
                tracing_state._traced_module_stack = []
            # NOTE(review): `name` is computed but unused; `scoped_name`
            # below (resolved from the root module) supersedes it.
            name = _tracing_name(self_, tracing_state)
            scoped_name = _hier_get_name_from_child(self.mother, self_, '')
            tracing_state.push_scope(scoped_name)
            tracing_state._traced_module_stack.append(self_)
            try:
                result = self_.forward(*input, **kwargs)
            finally:
                # always pop the scope/stack, even if forward() raised
                tracing_state.pop_scope()
                tracing_state._traced_module_stack.pop()
            return result

        self.backup = torch.nn.Module._slow_forward
        setattr(torch.nn.Module, '_slow_forward', _slow_forward)

    def __exit__(self, type, value, tb):
        # restore the unpatched _slow_forward
        setattr(torch.nn.Module, '_slow_forward', self.backup)

def onnx_name_2_pytorch_name(name):
    """Scope names look like 'Top/sub/leaf'; the last path component is the
    PyTorch module name."""
    return name.split('/')[-1]
len(self.outgoing)==0 else False 113 | 114 | def _traverse_forward(self, fn=None, recurse_max="inf", reduc_fn=lambda ret,x: ret+x, ret_default=0, **kwargs): 115 | ret = ret_default 116 | recurse_max = "inf" if recurse_max == "inf" else recurse_max-1 117 | for o in self.outgoing: 118 | ret = reduc_fn(ret, fn(o, **kwargs)) 119 | if recurse_max == "inf" or recurse_max > 0: 120 | ret = reduc_fn(ret, o._traverse_forward(fn, recurse_max=recurse_max, reduc_fn=reduc_fn, **kwargs)) 121 | return ret 122 | 123 | def _traverse_backward(self, fn=None, recurse_max="inf", reduc_fn=lambda ret,x: ret+x, ret_default=0, **kwargs): 124 | ret = ret_default 125 | recurse_max = "inf" if recurse_max == "inf" else recurse_max-1 126 | for o in self.incoming: 127 | ret = reduc_fn(ret, fn(o, **kwargs)) 128 | if recurse_max == "inf" or recurse_max > 0: 129 | ret = reduc_fn(ret, o._traverse_backward(fn, recurse_max=recurse_max, reduc_fn=reduc_fn, ret_default=ret_default, **kwargs)) 130 | return ret 131 | 132 | class DeployGraph(object): 133 | def __init__(self, module, dummy_input): 134 | if version.parse(torch.__version__) < version.parse('1.4.0'): 135 | trace, _, _ = torch.jit.get_trace_graph(module, dummy_input, _force_outplace=True, _return_inputs_states=True) 136 | torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX) 137 | graph = trace.graph() 138 | else: 139 | with scope_name_workaround(module): 140 | try: 141 | graph, _params_dict, _torch_out = torch.onnx.utils._model_to_graph(module, dummy_input, propagate=True, _retain_param_name=True) 142 | except TypeError: 143 | graph, _params_dict, _torch_out = torch.onnx.utils._model_to_graph(module, dummy_input, _retain_param_name=True) 144 | input_dict = {} 145 | output_dict = {} 146 | self.non_unique_names_dict = {} 147 | self.module = module 148 | 149 | # build list of inputs/outputs 150 | for i,node in enumerate(graph.nodes()): 151 | op_name = node.scopeName() 152 | module_name = onnx_name_2_pytorch_name(op_name) + "/" + 
    def __init__(self, module, dummy_input):
        """Build the dataflow graph of `module` by JIT/ONNX-tracing it on
        `dummy_input`.

        Populates:
          - self.nodes: OrderedDict of DeployNode keyed by
            "<pytorch name>/<op kind>_<i>"
          - self.module_nodes: flat map from the same keys to nn.Module objects
          - self.non_unique_names_dict: map from unique keys back to the
            (possibly non-unique) PyTorch module names
          - self.input_nodes: list of nodes with no incoming edges
        """
        # pre-1.4 and post-1.4 PyTorch expose tracing through different APIs
        if version.parse(torch.__version__) < version.parse('1.4.0'):
            trace, _, _ = torch.jit.get_trace_graph(module, dummy_input, _force_outplace=True, _return_inputs_states=True)
            torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
            graph = trace.graph()
        else:
            with scope_name_workaround(module):
                # the `propagate` kwarg was removed in later torch versions:
                # retry without it on TypeError
                try:
                    graph, _params_dict, _torch_out = torch.onnx.utils._model_to_graph(module, dummy_input, propagate=True, _retain_param_name=True)
                except TypeError:
                    graph, _params_dict, _torch_out = torch.onnx.utils._model_to_graph(module, dummy_input, _retain_param_name=True)
        input_dict = {}
        output_dict = {}
        self.non_unique_names_dict = {}
        self.module = module

        # build list of inputs/outputs
        for i,node in enumerate(graph.nodes()):
            op_name = node.scopeName()
            # NOTE(review): lstrip('::onnx') strips a *character set*, not the
            # literal "onnx::" prefix -- it would also eat leading lowercase
            # 'o'/'n'/'x' letters of the op kind itself; confirm no op kind
            # starts with one of those characters.
            module_name = onnx_name_2_pytorch_name(op_name) + "/" + node.kind().lstrip('::onnx') + "_" + str(i)
            self.non_unique_names_dict[module_name] = onnx_name_2_pytorch_name(op_name)
            # map each tensor debugName to its consumers (input_dict) and its
            # producer (output_dict)
            for i in node.inputs():  # NOTE(review): shadows the enumerate index `i` (harmless: `i` is reassigned each outer iteration)
                try:
                    input_dict[i.debugName()].append(module_name)
                except KeyError:
                    input_dict[i.debugName()] = [module_name,]
            for o in node.outputs():
                output_dict[o.debugName()] = module_name

        # build a flat dictionary of modules -- this is centralized in the DeployGraph to ease nemo.transform
        all_nodes = list(set([i for i in output_dict.values()]) | set([i for l in input_dict.values() for i in l]))
        self.module_nodes = OrderedDict([])
        for n in all_nodes:
            nn = n.split("/")[0]
            self.module_nodes[n] = _hier_flat_dict_build(module, nn)

        # build a flat dictionary of DeployNodes
        self.nodes = OrderedDict([])
        for n in all_nodes:
            self.nodes[n] = DeployNode(key=n)

        # populate outgoing connections: producer of a tensor points to all
        # of its consumers
        for ok in output_dict.keys():
            out = output_dict[ok]
            try:
                ils = input_dict[ok]
            except KeyError:
                # unused outputs
                ils = None
            if ils is not None:
                dnl = [self.nodes[i] for i in ils]
                self.nodes[out].outgoing.extend(dnl)

        # populate incoming connections (mirror of outgoing)
        nodes_copy = copy.copy(self.nodes) # shallow copy
        for k,n in nodes_copy.items():
            for i,m in enumerate(n.outgoing):
                self.nodes[k].outgoing[i].incoming.append(n)

        # identify input nodes (only 1!) FIXME
        self.input_nodes = []
        for k,n in self.nodes.items():
            if len(n.incoming) == 0:
                n.input_node = True
                self.input_nodes.append(n)

        self.jit_graph = graph

    def rebuild_module_dict(self):
        """Refresh self.module_nodes after modules were swapped by nemo.transform."""
        # build a flat dictionary of modules -- this is centralized in the DeployGraph to ease nemo.transform
        for n in self.module_nodes.keys():
            nn = n.split("/")[0]
            self.module_nodes[n] = _hier_flat_dict_build(self.module, nn)

    def print_modules(self):
        """Debug print: node key -> nn.Module."""
        for k,n in self.nodes.items():
            print (k, self.module_nodes[k])

    def print_forward_edges(self):
        """Debug print: node key -> keys of its successors."""
        for k,n in self.nodes.items():
            print (k, [m.key for m in n.outgoing])

    def print_backward_edges(self):
        """Debug print: node key -> keys of its predecessors."""
        for k,n in self.nodes.items():
            print (k, [m.key for m in n.incoming])

    def print_jit_graph(self):
        """Debug print of the raw traced JIT graph."""
        print(self.jit_graph)

    def get_eps_at(self, key, eps_in, use_non_unique_name=True, verbose=False):
        """Propagate the input quantum `eps_in` through the graph up to node `key`.

        Returns a single eps (as a tensor) when the target has one incoming
        edge, a list of eps values (one per incoming edge) otherwise, or None
        when `key` does not name a known node.
        """
        # back-track route to input
        # the procedure is repeated for each incoming edge to the target node
        target = None
        if use_non_unique_name:
            # resolve the (possibly non-unique) PyTorch name; first match wins
            for kk,el in list(self.non_unique_names_dict.items()):
                if el == key:
                    target = self.nodes[kk]
                    break
        else:
            target = self.nodes.get(key, None)
        if target is None:
            print("[nemo-graph] Warning: %s is not a module name" % key)
            return None
        eps_list = []
        for incoming_idx in range(len(target.incoming)):
            curr = target
            route = []
            # walk back to the graph input, remembering at each hop which
            # outgoing slot leads forward toward the target
            while not curr.is_input():
                if verbose:
                    print("[nemo-graph] backward %s" % (curr.key))
                k = curr.key
                # if current node is the target, use the incoming_idx route
                if curr == target:
                    curr = curr.incoming[incoming_idx]
                else:
                    curr = curr.incoming[0]
                route.append(curr.outgoing.index(self.nodes[k]))
            # forward-track route to node, computing eps
            route = route[::-1]
            eps = eps_in
            for idx in route:
                if verbose:
                    print("[nemo-graph] forward %s %d" % (curr.key, idx))
                # only quantization-aware modules transform eps
                if hasattr(self.module_nodes[curr.key], 'get_output_eps'):
                    eps = self.module_nodes[curr.key].get_output_eps(eps)
                try:
                    curr = curr.outgoing[idx]
                except IndexError:
                    print("[nemo-graph] Warning: %s has no outgoing edge" % (curr.key))
                    break
            # normalize plain floats to tensors for a uniform return type
            if type(eps) is float:
                eps_list.append(torch.tensor(eps))
            else:
                eps_list.append(eps)
        if len(eps_list) == 1:
            return eps_list[0]
        else:
            return eps_list
eps_in 253 | for idx in route: 254 | if verbose: 255 | print("[nemo-graph] forward %s %d" % (curr.key, idx)) 256 | if hasattr(self.module_nodes[curr.key], 'get_output_eps'): 257 | eps = self.module_nodes[curr.key].get_output_eps(eps) 258 | try: 259 | curr = curr.outgoing[idx] 260 | except IndexError: 261 | print("[nemo-graph] Warning: %s has no outgoing edge" % (curr.key)) 262 | break 263 | if type(eps) is float: 264 | eps_list.append(torch.tensor(eps)) 265 | else: 266 | eps_list.append(eps) 267 | if len(eps_list) == 1: 268 | return eps_list[0] 269 | else: 270 | return eps_list 271 | 272 | def get_supernodes(self, verbose=False): 273 | # collect all activation nodes 274 | actnodes = [] 275 | for k,n in self.nodes.items(): 276 | if isinstance(self.module_nodes[n.key], PACT_Act) or \ 277 | isinstance(self.module_nodes[n.key], PACT_ThresholdAct) or \ 278 | isinstance(self.module_nodes[n.key], PACT_IntegerAct): 279 | actnodes.append(n) 280 | supernodes = OrderedDict([]) 281 | # for each activation node, backtrack until another activation node is found 282 | for target in actnodes: 283 | # here we assume all activation nodes have only one incoming path, which should be reasonable 284 | curr = target.incoming[0] 285 | route = [] 286 | while not isinstance(self.module_nodes[curr.key], PACT_Act) or \ 287 | isinstance(self.module_nodes[curr.key], PACT_ThresholdAct) or \ 288 | isinstance(self.module_nodes[curr.key], PACT_IntegerAct): 289 | route.append((self.non_unique_names_dict[curr.key], self.module_nodes[curr.key])) 290 | if verbose: 291 | print("[nemo-graph] backward %s" % (curr.key)) 292 | try: 293 | curr = curr.incoming[0] 294 | except IndexError: 295 | break 296 | # forward-track route to node 297 | supernodes[self.non_unique_names_dict[target.key]] = {'supernode': route[::-1], 'previous': self.non_unique_names_dict[curr.key]} 298 | return supernodes 299 | -------------------------------------------------------------------------------- /nemo/precision.py: 
class Precision():
    """A fixed-point precision: a number of bits plus a scale (clipping) factor.

    Comparison operators order two Precision objects by bit count; a plain
    number on the right-hand side is treated as a bit count directly.

    :param bits: number of bits, or None for "unquantized" (get_bits() then
        reports MAX_NB_BITS and get_eps() MIN_REPR_PRECISION).
    :param scale: dynamic range / clipping value, or None (get_scale() then
        reports MAX_REPR_SCALE).
    :param positive: True for unsigned representations (no sign bit).
    """

    def __init__(self, bits=None, scale=None, positive=False):
        super(Precision, self).__init__()
        self.bits = bits
        self.scale = scale
        self.positive = positive

    @staticmethod
    def _other_bits(other):
        """Bit count of the comparison operand: Precision -> .bits, number -> itself."""
        try:
            return other.bits
        except AttributeError:
            return other

    # All comparisons are by bit count only (scale is deliberately ignored,
    # matching the original operator set).
    def __gt__(self, other):
        return self.bits > Precision._other_bits(other)

    def __lt__(self, other):
        return self.bits < Precision._other_bits(other)

    def __ge__(self, other):
        return self.bits >= Precision._other_bits(other)

    def __le__(self, other):
        return self.bits <= Precision._other_bits(other)

    def __eq__(self, other):
        return self.bits == Precision._other_bits(other)

    def __ne__(self, other):
        return self.bits != Precision._other_bits(other)

    def set_eps(self, eps):
        # FIX: was `np.log2(-eps)`, which is NaN for every positive eps.
        # bits = -log2(eps) inverts get_eps() for scale=1; NOTE(review): like
        # the original, this ignores self.scale and self.positive -- confirm
        # whether callers expect the exact inverse of get_eps().
        self.bits = -np.log2(eps)

    def set_bits(self, bits):
        self.bits = bits

    def set_clip(self, clip):
        # alias of set_scale: the clipping value *is* the scale
        self.scale = clip

    def set_scale(self, scale):
        self.scale = scale

    def get_eps(self):
        """Quantum of the representation: 2^-(bits-1)*scale (signed) or
        2^-bits*scale (unsigned); MIN_REPR_PRECISION when unconfigured."""
        if self.bits is None or self.scale is None:
            return MIN_REPR_PRECISION
        if not self.positive:
            # one bit is reserved for the sign
            return 2.0**(-(self.bits-1)) * self.scale
        else:
            return 2.0**(-self.bits) * self.scale

    def get_clip(self):
        return self.get_scale()

    def get_scale(self):
        if self.scale is None:
            return MAX_REPR_SCALE
        return self.scale

    def get_bits(self):
        if self.bits is None:
            return MAX_NB_BITS
        return self.bits
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | __all__ = ["pact", "lq_quant"] 20 | 21 | -------------------------------------------------------------------------------- /nemo/relaxation.py: -------------------------------------------------------------------------------- 1 | # 2 | # relaxation.py 3 | # Francesco Conti 4 | # Alfio Di Mauro 5 | # 6 | # Copyright (C) 2018-2020 ETH Zurich 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import numpy as np 21 | import logging 22 | import nemo 23 | import torch 24 | from collections import OrderedDict 25 | from tqdm import tqdm 26 | 27 | class RelaxationEngine(): 28 | r"""Main engine class for weight/activation precision relaxation procedure. 29 | 30 | :param net: Module or network on which the relaxation procedure will be performed (must have a `change_precision` method). 31 | :type net: `torch.nn.Module` 32 | 33 | :param optimizer: A pointer to the optimizer being used for the training. 
34 | :type optimizer: `torch.optim.Optimizer` 35 | 36 | :param criterion: Loss function being used as a quality metric. 37 | :type criterion: `torch.autograd.Function` 38 | 39 | :param trainloader: Loader used for training values (used for relaxation evaluation). 40 | :type trainloader: `torch.utils.data.DataLoader` 41 | 42 | :param precision_rule: A dictionary describing the rules to be used for the precision relaxation. 43 | :type precision_rule: `dict` or `collections.OrderedDict` 44 | 45 | :param tbx_writer: TensorBoardX writer for data logging. Defaults to `None`. 46 | :type tbx_writer: `tensorboardx.SummaryWriter` 47 | 48 | :param reset_alpha_weights: If True, reset the W_alpha and W_beta parameters at precision change. 49 | :type reset_alpha_weights: `bool` 50 | 51 | :param min_prec_dict: Dictionary of minimum allowed precision for all parameters. 52 | :type min_prec_dict: `dict` or `collections.OrderedDict` 53 | 54 | :param evaluator: Evaluation engine for precision selection heuristics. 55 | :type evaluator: `nemo.evaluation.EvaluationEngine` 56 | 57 | :param evaluator_threshold: Threshold to be used for precision binning (default 0.9). 58 | :type evaluator_threshold: `float` 59 | 60 | :param evaluator_verbose: If True, print more information from evaluation engine. 61 | :type evaluator_verbose: `bool` 62 | 63 | :param evaluator_strategy: Can be 'max_accuracy' (default) or 'min_precision'. 64 | :type evaluator_strategy: `string` 65 | 66 | :param divergence_policy: Can be 'change_precision' (default) or 'lower_lr'. 67 | :type divergence_policy: `string` 68 | 69 | :param divergence_lr_scaler: LR scaling factor for 'lower_lr' divergence policy. 70 | :type divergence_lr_scaler: `float` 71 | 72 | The :py:class:`RelaxationEngine` represents a precision-scaling procedure utilizing the relaxation 73 | heuristic. 74 | The main parameters of the relaxation heuristic are passed through the `precision_rule` 75 | dictionary when initializing the class. 
76 | 77 | Assuming `p` is the `precision_rule`, the relaxation heuristic keeps an internal memory 78 | of the loss and its relative variation (`delta_loss`) over the last `p[running_avg_memory]` epochs. 79 | The mean and standard deviation of `delta_loss` over `p[running_avg_memory]` epochs are compared with 80 | `p[delta_loss_less_than]` and `p[delta_loss_running_std_stale]`. Moreover the absolute loss value 81 | is compared with `p[abs_loss_stale]`. 82 | 83 | Two counters are updated: 84 | - for each consecutive epoch in which all the three values compared are less than the respective 85 | parameters from `p`, an `abs_bound` counter is updated 86 | - for each consecutive epoch in which the two delta values compared are less than the respective 87 | parameters from `p`, a `no_abs_bound` counter is updated 88 | 89 | After the counters are updated, their respective value is compared with `p[for_epochs]` and 90 | `p[for_epochs_no_abs_bound]`: if any of the counters is higher than the respective parameter, the 91 | training is considered "stale" for the current quantization value. 92 | 93 | When this happens, precision is scaled down by a factor of `p[bit_scaler]` bits, up to the point 94 | when `p[bit_stop_condition]` precision is reached. If `p[scale_lr]` is set to true, the learning 95 | rate is also downscaled by a factor of `p[lr_scaler]`. 
    def __init__(self, net, optimizer, criterion, trainloader, precision_rule=None, tbx_writer=None, reset_alpha_weights=True, min_prec_dict=None, evaluator=None, evaluator_threshold=0.9, evaluator_verbose=False, evaluator_strategy='min_precision', divergence_policy='change_precision', divergence_lr_scaler=0.2, evaluate=None, log_start=None, log_stop=None, log_step=None, validate_on_train_fn=None, reset_alpha_below_6bits=False):
        super(RelaxationEngine, self).__init__()

        # activation/weight scaling flags default to True unless the
        # precision_rule explicitly overrides them
        self.precision_rule = precision_rule
        if self.precision_rule is None:
            self.scale_activations = True
            self.scale_weights = True
        else:
            try:
                self.scale_activations = self.precision_rule['scale_activations']
            except KeyError:
                self.scale_activations = True
            try:
                self.scale_weights = self.precision_rule['scale_weights']
            except KeyError:
                self.scale_weights = True

        if self.precision_rule is not None:
            # running statistics of the loss delta over the last
            # `running_avg_memory` epochs (consumed by step() to detect
            # training staleness)
            self.delta_loss_running_avg = 1000.0
            self.delta_loss_running_std = 0.0
            self.delta_loss_memory_curr_size = 0
            self.delta_loss_memory_size = self.precision_rule['running_avg_memory']
            self.delta_loss_memory = np.zeros(self.precision_rule['running_avg_memory'])

            # ensure the 'custom_scaler' key exists (default None); the value
            # read into `cs` is intentionally discarded
            try:
                cs = self.precision_rule['custom_scaler']
            except KeyError:
                self.precision_rule['custom_scaler'] = None

        # best/previous loss trackers start at a large sentinel value
        self.loss_best = 1e3
        self.loss_prev = 1e3
        self.net = net
        self.optimizer = optimizer
        # consecutive-epoch staleness counters (see the class docstring for
        # how they interact with for_epochs / for_epochs_no_abs_bound)
        self.precision_abs_bound_counter = 0
        self.precision_no_abs_bound_counter = 0
        self.criterion = criterion
        self.trainloader = trainloader

        self.tbx_writer = tbx_writer
        self.reset_alpha_weights = reset_alpha_weights

        self.min_prec_dict = min_prec_dict
        # evaluation-engine hooks for the precision selection heuristics
        self.evaluator = evaluator
        self.evaluator_threshold = evaluator_threshold
        self.evaluator_verbose = evaluator_verbose
        self.evaluator_strategy = evaluator_strategy

        # divergence handling: after divergence_lrscaling_limit LR scalings,
        # step() falls back to the 'change_precision' policy
        self.divergence_policy = divergence_policy
        self.divergence_lr_scaler = divergence_lr_scaler
        self.divergence_cnt = 0
        self.divergence_lrscaling_limit = 3

        self.validate_on_train_fn = validate_on_train_fn

        self.reset_alpha_below_6bits = reset_alpha_below_6bits

        self.relaxation_ended = False
        # NOTE(review): `evaluate`, `log_start`, `log_stop` and `log_step` are
        # accepted but never stored or used in this constructor -- confirm
        # whether they are dead parameters.
170 | """ 171 | 172 | reset_alpha_weights = self.reset_alpha_weights 173 | if self.precision_rule is None: 174 | return 175 | 176 | # save best result in case of catastrophic failure 177 | if loss < self.loss_best: 178 | self.loss_best = loss 179 | if checkpoint_name is not None: 180 | nemo.utils.save_checkpoint(self.net, self.optimizer, epoch-1, checkpoint_name=checkpoint_name, checkpoint_suffix="_current_best") 181 | else: 182 | nemo.utils.save_checkpoint(self.net, self.optimizer, epoch-1, checkpoint_name="checkpoint", checkpoint_suffix="_current_best") 183 | 184 | try: 185 | curr_regime = self.precision_rule[epoch] 186 | except KeyError: 187 | try: 188 | curr_regime = self.precision_rule[str(epoch)] 189 | except KeyError: 190 | curr_regime = None 191 | 192 | for i in range(self.delta_loss_memory_size-1, 0, -1): 193 | self.delta_loss_memory[i] = self.delta_loss_memory[i-1] 194 | self.delta_loss_memory[0] = self.loss_prev-loss 195 | if self.delta_loss_memory_curr_size < self.delta_loss_memory_size: 196 | self.delta_loss_memory_curr_size += 1 197 | if self.delta_loss_memory_curr_size > 1: 198 | delta_loss_running_avg = self.delta_loss_memory[:self.delta_loss_memory_curr_size-1].mean() 199 | delta_loss_running_std = self.delta_loss_memory[:self.delta_loss_memory_curr_size-1].std() 200 | else: 201 | delta_loss_running_avg = np.Inf 202 | delta_loss_running_std = np.Inf 203 | logging.info("[Relax @%d]\t delta_loss_running_avg=%.3e loss_epoch_m1=%.3e delta_loss=%.3e" % (epoch-1, delta_loss_running_avg, loss, self.loss_prev-loss)) 204 | logging.info("[Relax @%d]\t delta_loss_memory=%s" % (epoch-1, self.delta_loss_memory[:self.delta_loss_memory_curr_size-1])) 205 | 206 | if self.tbx_writer is not None: 207 | self.tbx_writer.add_scalars('train', { 'delta_loss_avg': delta_loss_running_avg, 'delta_loss_std': delta_loss_running_std }, epoch+1) 208 | 209 | # staleness happens when 1) delta_loss has bounded mean and std and absolute loss is bounded for for_epochs, or 2) 
delta_loss has bounded mean and std for for_epochs_no_abs_bound 210 | if delta_loss_running_avg < self.precision_rule['delta_loss_less_than'] and delta_loss_running_std < self.precision_rule['delta_loss_running_std_stale'] and loss < self.precision_rule['abs_loss_stale']: 211 | self.precision_abs_bound_counter += 1 212 | else: 213 | self.precision_abs_bound_counter = 0 214 | if delta_loss_running_avg < self.precision_rule['delta_loss_less_than'] and delta_loss_running_std < self.precision_rule['delta_loss_running_std_stale']: 215 | self.precision_no_abs_bound_counter += 1 216 | else: 217 | self.precision_no_abs_bound_counter = 0 218 | 219 | divergence_chprec_flag = False 220 | divergence_lowlr_flag = False 221 | # catastrophic failure occurs when delta_loss_running_avg is negative 222 | try: 223 | divergence_abs_threshold = self.precision_rule['divergence_abs_threshold'] 224 | except KeyError: 225 | divergence_abs_threshold = 1e9 226 | if delta_loss_running_avg < 0 or loss > self.precision_rule['divergence_abs_threshold']: 227 | # recover previous state 228 | state = torch.load("checkpoint/"+checkpoint_name+"_current_best.pth")['state_dict'] 229 | self.net.load_state_dict(state, strict=True) 230 | logging.info("[Relax @%d]\t Detected divergent training, restoring previous best state." 
% (epoch-1)) 231 | 232 | if (self.divergence_policy == 'lower_lr' and self.divergence_cnt < self.divergence_lrscaling_limit) or self.relaxation_ended: 233 | # reset delta loss memory (with a small delta) 234 | self.precision_abs_bound_counter = 0 235 | self.precision_no_abs_bound_counter = 0 236 | loss = self.loss_best 237 | self.delta_loss_memory[:] = 0 238 | self.delta_loss_memory_curr_size = 1 239 | self.divergence_cnt += 1 240 | if self.divergence_cnt == 1: 241 | self.divergence_saved_lr = list(self.optimizer.param_groups)[0]['lr'] 242 | 243 | # scale LR together with W bits 244 | for p in self.optimizer.param_groups: 245 | p['lr'] *= self.divergence_lr_scaler 246 | lr = p['lr'] 247 | logging.info("[Relax @%d]\t Using 'lower_lr' policy (iter %d); scaled LR to %.3e" % (epoch-1, self.divergence_cnt, lr)) 248 | 249 | # report this as a precision change 250 | divergence_lowlr_flag = True 251 | elif self.divergence_policy == 'change_precision' or self.divergence_cnt >= self.divergence_lrscaling_limit: 252 | divergence_chprec_flag = True 253 | self.divergence_cnt = 0 254 | logging.info("[Relax @%d]\t Using 'change_precision' policy." 
% (epoch-1)) 255 | else: 256 | raise NotImplementedError 257 | 258 | # if relaxation has already ended, exit here 259 | if self.relaxation_ended: 260 | return divergence_lowlr_flag, True 261 | 262 | change_precision = divergence_lowlr_flag # False in normal cases 263 | 264 | if curr_regime is not None: 265 | self.precision_abs_bound_counter = 0 266 | self.precision_no_abs_bound_counter = 0 267 | if loss is None: 268 | loss = 1e3 269 | self.loss_prev = 1e3 270 | suffix = '_' + "%.1f" % (self.net.W_precision.get_bits()) + 'b' 271 | suffix += 'x' + "%.1f" % (self.net.x_precision.get_bits()) + 'b' 272 | if checkpoint_name is not None: 273 | nemo.utils.save_checkpoint(self.net, self.optimizer, epoch-1, checkpoint_name=checkpoint_name, checkpoint_suffix=suffix) 274 | self.net.change_precision(bits=curr_regime['W_bits'], scale_activations=False, scale_weights=True, reset_alpha=reset_alpha_weights, min_prec_dict=self.min_prec_dict) 275 | self.net.change_precision(bits=curr_regime['x_bits'], scale_activations=True, scale_weights=False, reset_alpha=reset_alpha_weights, min_prec_dict=self.min_prec_dict) 276 | 277 | # save checkpoint for catastrophic failure case 278 | self.loss_best = 1e3 279 | if checkpoint_name is None: 280 | nemo.utils.save_checkpoint(self.net, self.optimizer, epoch-1, checkpoint_name='checkpoint', checkpoint_suffix="_current_best") 281 | else: 282 | nemo.utils.save_checkpoint(self.net, self.optimizer, epoch-1, checkpoint_name=checkpoint_name, checkpoint_suffix="_current_best") 283 | 284 | elif self.precision_abs_bound_counter == self.precision_rule['for_epochs'] or self.precision_no_abs_bound_counter == self.precision_rule['for_epochs_no_abs_bound'] or divergence_chprec_flag: 285 | 286 | if self.precision_abs_bound_counter == self.precision_rule['for_epochs']: 287 | logging.info("[Relax @%d]\t precision_abs_bound_counter=%d: Triggering precision change below absolute loss threshold" % (epoch-1, self.precision_abs_bound_counter)) 288 | elif 
self.precision_no_abs_bound_counter == self.precision_rule['for_epochs_no_abs_bound']: 289 | logging.info("[Relax @%d]\t precision_no_abs_bound_counter=%d: Triggering precision change above absolute loss threshold" % (epoch-1, self.precision_no_abs_bound_counter)) 290 | 291 | self.precision_abs_bound_counter = 0 292 | self.precision_no_abs_bound_counter = 0 293 | loss = 1e3 294 | self.loss_prev = 2e3 295 | self.delta_loss_memory[:] = 1e3 296 | self.delta_loss_memory_curr_size = 1 297 | 298 | try: 299 | scale_x = self.precision_rule['scale_x'] 300 | except KeyError: 301 | scale_x = True 302 | try: 303 | scale_W = self.precision_rule['scale_W'] 304 | except KeyError: 305 | scale_W = True 306 | 307 | # stop condition is currently measured against W_precision.bits 308 | if (self.net.W_precision.get_bits() >= self.precision_rule['W_bit_stop_condition']) or \ 309 | (self.net.x_precision.get_bits() >= self.precision_rule['x_bit_stop_condition']): 310 | 311 | # save checkpoint 312 | if checkpoint_name is not None: 313 | suffix = '_' + "%.1f" % (self.net.W_precision.get_bits()) + 'b' 314 | suffix += 'x' + "%.1f" % (self.net.x_precision.get_bits()) + 'b' 315 | nemo.utils.save_checkpoint(self.net, self.optimizer, epoch-1, checkpoint_name=checkpoint_name, checkpoint_suffix=suffix) 316 | 317 | # if there is an EvaluationEngine available, use it 318 | if self.evaluator is not None: 319 | W_start = self.net.W_precision.get_bits() 320 | x_start = self.net.x_precision.get_bits() 321 | # pure heuristics... 
322 | if W_start > 8: 323 | W_step = -2 324 | W_stop = max(6, self.precision_rule['W_bit_stop_condition']-2) if scale_W else W_start+W_step 325 | elif W_start > 6: 326 | W_step = -1 327 | W_stop = max(5, self.precision_rule['W_bit_stop_condition']-1) if scale_W else W_start+W_step 328 | else: 329 | W_step = -0.5 330 | W_stop = max(3.5, self.precision_rule['W_bit_stop_condition']-0.5) if scale_W else W_start+W_step 331 | if x_start > 8: 332 | x_step = -2 333 | x_stop = max(6, self.precision_rule['x_bit_stop_condition']-2) if scale_x else x_start+x_step 334 | elif x_start > 6: 335 | x_step = -1 336 | x_stop = max(5, self.precision_rule['x_bit_stop_condition']-1) if scale_x else x_start+x_step 337 | else: 338 | x_step = -0.5 339 | x_stop = max(3.5, self.precision_rule['x_bit_stop_condition']-0.5) if scale_x else x_start+x_step 340 | self.net.unset_train_loop() # this is a "soft" weight hardening 341 | self.evaluator.reset_grids(W_start, W_stop, W_step, x_start, x_stop, x_step) 342 | while self.evaluator.step(): 343 | acc = self.evaluator.validate_fn(0, val_loader=self.evaluator.validate_data) 344 | self.evaluator.report(acc) 345 | logging.info("[Relax @%d]\t %.1f-bit W, %.1f-bit x %.2f%%" % (epoch-1, self.evaluator.wgrid[self.evaluator.idx], self.evaluator.xgrid[self.evaluator.idx], 100*acc.item())) 346 | self.net.set_train_loop() # this removes the "soft" hardening 347 | Wbits, xbits = self.evaluator.get_next_config(upper_threshold=self.evaluator_threshold, verbose=self.evaluator_verbose, strategy=self.evaluator_strategy) 348 | Wdiff = W_start - Wbits 349 | elif self.precision_rule['custom_scaler'] is not None: 350 | scaler = self.precision_rule['custom_scaler'] 351 | if len(scaler) == 0: 352 | return True, True 353 | Wbits, xbits, lrscaled, divpol = scaler.pop(0) 354 | self.divergence_policy = divpol # update divergence policy 355 | W_diff = self.net.W_precision.get_bits() - Wbits 356 | else: 357 | Wdiff = -self.precision_rule['W_bit_scaler'] 358 | Wbits = 
self.net.W_precision.get_bits()+self.precision_rule['W_bit_scaler'] 359 | xbits = self.net.x_precision.get_bits()+self.precision_rule['x_bit_scaler'] 360 | 361 | logging.info("[Relax @%d]\t Choosing %.1f-bit W, %.1f-bit x for next step" % (epoch-1, Wbits, xbits)) 362 | if scale_W: 363 | self.net.change_precision(bits=Wbits, scale=self.net.W_precision.get_scale(), scale_activations=False, scale_weights=self.scale_weights, reset_alpha=reset_alpha_weights, min_prec_dict=self.min_prec_dict) 364 | # this will reset alpha,beta PACT parameters to 5 standard deviations upon precision change, to avoid wasting dynamic range to represent irrealistic weights 365 | if Wbits < 6 and self.reset_alpha_below_6bits: 366 | self.net.reset_alpha_weights(stdev=5.) 367 | logging.info("[Relax @%d]\t Setting alpha,beta params of weights to %.1f std deviations" % (epoch-1, 5)) 368 | if scale_x: 369 | self.net.change_precision(bits=xbits, scale=self.net.x_precision.get_scale(), scale_activations=self.scale_activations, scale_weights=False, reset_alpha=reset_alpha_weights, min_prec_dict=self.min_prec_dict) 370 | change_precision = True 371 | if self.divergence_cnt > 0: 372 | for p in self.optimizer.param_groups: 373 | p['lr'] = self.divergence_saved_lr 374 | self.divergence_cnt = 0 375 | try: 376 | if self.evaluator is not None: 377 | # scale LR together with W bits 378 | for p in self.optimizer.param_groups: 379 | p['lr'] *= (2**Wdiff) 380 | lr = p['lr'] 381 | logging.info("[Relax @%d]\t Scaled LR to %.3e" % (epoch-1, lr)) 382 | elif self.precision_rule['custom_scaler'] is not None: 383 | for p in self.optimizer.param_groups: 384 | p['lr'] = lrscaled 385 | lr = p['lr'] 386 | logging.info("[Relax @%d]\t Scaled LR to %.3e" % (epoch-1, lr)) 387 | elif self.precision_rule['scale_lr']: 388 | for p in self.optimizer.param_groups: 389 | p['lr'] *= self.precision_rule['lr_scaler'] 390 | lr = p['lr'] 391 | logging.info("[Relax @%d]\t Scaled LR to %.3e" % (epoch-1, lr)) 392 | except KeyError: 393 | 
pass 394 | if self.validate_on_train_fn is not None: 395 | loss = self.validate_on_train_fn(epoch) 396 | logging.info("[Relax @%d]\t validate_on_train loss=%.3e" % (epoch-1, loss.item())) 397 | 398 | # save checkpoint for catastrophic failure case 399 | self.loss_best = loss 400 | nemo.utils.save_checkpoint(self.net, self.optimizer, epoch-1, checkpoint_name=checkpoint_name, checkpoint_suffix="_current_best") 401 | 402 | else: 403 | self.relaxation_ended = True 404 | logging.info("[Relax @%d]\t Precision relaxation procedure ended" % (epoch-1)) 405 | return True, True 406 | 407 | if self.tbx_writer is not None: 408 | self.tbx_writer.add_scalars('train', { 'abs_bound_counter': self.precision_abs_bound_counter, 'no_abs_bound_counter': self.precision_no_abs_bound_counter, }, epoch+1) 409 | lr_save = list(self.optimizer.param_groups)[0]['lr'] 410 | self.tbx_writer.add_scalars('train', { 'lr': lr_save }, epoch+1) 411 | logging.info("[Relax @%d]\t delta_loss_running_avg=%.3e loss_epoch_m1=%.3e delta_loss=%.3e" % (epoch-1, delta_loss_running_avg, loss, self.loss_prev-loss)) 412 | logging.info("[Relax @%d]\t delta_loss_memory=%s" % (epoch-1, self.delta_loss_memory[:self.delta_loss_memory_curr_size-1])) 413 | logging.info("[Relax @%d]\t precision_abs_bound_counter=%d precision_no_abs_bound_counter=%d" % (epoch-1, self.precision_abs_bound_counter, self.precision_no_abs_bound_counter)) 414 | 415 | self.loss_prev = loss 416 | 417 | # if precision has changed, signal this upstream 418 | return change_precision, False 419 | -------------------------------------------------------------------------------- /nemo/transf/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | # Francesco Conti 3 | # Alfio Di Mauro 4 | # 5 | # Copyright (C) 2018-2020 ETH Zurich 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | __all__ = ["bias", "bn", "common", "deploy", "equalize", "export", "pruning", "statistics", "utils", "sawb"] 20 | from . import bias, bn, common, deploy, equalize, export, pruning, statistics, utils, sawb 21 | -------------------------------------------------------------------------------- /nemo/transf/bias.py: -------------------------------------------------------------------------------- 1 | # 2 | # bias.py 3 | # Francesco Conti 4 | # 5 | # Copyright (C) 2018-2020 ETH Zurich 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
def _add_input_bias_pact(self, lin_dict, eps_in=None):
    r"""Adds a bias to compensate the asymmetry of an input activation tensor.

    :param lin_dict: a dictionary, where the key is the linear layer name and the value is the related activation translation.
    :type lin_dict: `dict` or `collections.OrderedDict`
    :param eps_in: input quantum; if not None, padding values are floored to a multiple of it.
    :type eps_in: float

    """

    # collect both linear layers (translation folded into the bias) and padding
    # layers (translation applied to the pad value)
    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear" or \
            m.__class__.__name__ == "ConstantPad2d"):
            module_dict[n] = m
    # print(lin_dict)
    for n in lin_dict.keys():
        m = module_dict[n]
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear"):
            # fold the input translation t into the bias: b' = b - t * sum(W)
            # NOTE(review): the .sum(3).sum(2).sum(1) chain assumes 4-D weights
            # (PACT_Conv2d) -- confirm it is never reached for Conv1d/Linear entries
            try:
                m.bias.data[:] = m.bias.data[:] - lin_dict[n] * m.weight.data[:].sum(3).sum(2).sum(1)
            except AttributeError:
                # layer had no bias: create one holding only the compensation term
                m.bias = torch.nn.Parameter(-lin_dict[n] * m.weight.data[:].sum(3).sum(2).sum(1))
            if eps_in is None:
                m.padding_value = lin_dict[n]
            else:
                # floor the padding value to a multiple of the input quantum
                m.padding_value = math.floor(lin_dict[n]/eps_in)*eps_in
        elif eps_in is None:
            # ConstantPad2d: shift the constant pad value itself
            m.value = lin_dict[n]
        else:
            m.value = math.floor(lin_dict[n]/eps_in)*eps_in
    # remember the applied translations so they can be reverted later
    self.input_bias_dict = lin_dict

def _remove_input_bias_pact(self):
    r"""Reverts :py:func:`_add_input_bias_pact`, restoring original biases and padding values."""
    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear"):
            module_dict[n] = m
    for n in self.input_bias_dict.keys():
        m = module_dict[n]
        # undo the b' = b - t * sum(W) folding performed at insertion time
        m.bias.data[:] = m.bias.data[:] + self.input_bias_dict[n] * m.weight.data[:].sum(3).sum(2).sum(1)
        m.padding_value = 0
    self.input_bias_dict = None

def _remove_bias_pact(self, bn_dict={}):
    r"""Folds the bias of a linear layer into the parameters of a following batch-norm.

    :param bn_dict: a dictionary of layer names, with the key being the source (linear) and the value the target (batch-norm). If empty (default), it is derived from the graph's supernodes.
    :type bn_dict: `dict` or `collections.OrderedDict`

    """

    if not bn_dict:
        bn_dict = get_bn_dict_from_supernodes(self)

    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear" or \
            m.__class__.__name__ == "BatchNorm2d" or \
            m.__class__.__name__ == "BatchNorm1d" ):
            module_dict[n] = m
    for n_before in bn_dict.keys():
        n_after = bn_dict[n_before]
        m_before = module_dict[n_before]
        m_after = module_dict[n_after]
        # BN computes gamma*(x - mean)/sigma + beta, so shifting running_mean
        # by -bias absorbs the linear layer's bias exactly
        m_after.running_mean.data[:] -= m_before.bias.data[:]
        m_before.bias = None
# hard cap on the bit-width of quantized BN parameters
BN_PRECISION_MAX = 20

def _absorb_affine_bn(self):
    r"""Rewrites BatchNorm2d parameters so that the affine weight/bias directly hold
    the folded per-channel scale (kappa) and offset (lamda), neutralizing the
    running statistics."""
    for n,m in self.named_modules():
        if m.__class__.__name__ == "BatchNorm2d":
            # weight becomes gamma/sigma; the bias update below deliberately uses
            # the ALREADY-rescaled weight, yielding beta - gamma/sigma * mu
            m.weight.data[:] = m.weight.data[:] / torch.sqrt(m.running_var.data[:] + m.eps)
            m.bias.data[:] = m.bias.data[:] - m.weight.data[:] * m.running_mean.data[:]
            # set var/mean so that (x - mean)/sqrt(var + eps) is the identity
            m.running_var.data[:] = (1. - m.eps)**2
            m.running_mean.data[:] = 0.

def _prime_affine_bn(self):
    r"""Inverse of :py:func:`_absorb_affine_bn`: re-expresses absorbed kappa/lamda
    parameters in terms of the (non-neutral) running statistics."""
    for n,m in self.named_modules():
        if m.__class__.__name__ == "BatchNorm2d":
            m.weight.data[:] = m.weight.data[:] * torch.sqrt(m.running_var.data[:] + m.eps)
            m.bias.data[:] = m.bias.data[:] + m.weight.data[:] * m.running_mean.data[:] / torch.sqrt(m.running_var.data[:] + m.eps)

def _freeze_bn(self, reset_stats=False, disable_grad=False):
    r"""Sets :py:class:`torch.nn.BatchNorm2d` layers to not collect statistics and keep the current `running_var` and `running_mean`.

    :param reset_stats: if True, fold current statistics into weight/bias (kappa/lamda form) and neutralize running_var/running_mean.
    :type reset_stats: bool
    :param disable_grad: if True, also stop gradient flow through the BN affine parameters.
    :type disable_grad: bool

    """

    for n,m in self.named_modules():
        if m.__class__.__name__ == "BatchNorm2d":
            if reset_stats:
                try:
                    eps = m.eps
                except AttributeError:
                    eps = 0.
                gamma = m.weight.data[:].clone().detach().cpu()
                beta = m.bias.data[:].clone().detach().cpu()
                sigma = torch.sqrt(m.running_var.data[:] + eps).clone().detach().cpu()
                mu = m.running_mean.data[:].clone().detach().cpu()
                # fold statistics into the affine parameters: y = kappa*x + lamda
                kappa = gamma/sigma
                lamda = beta-gamma/sigma * mu
                m.weight.data[:] = kappa
                m.bias.data[:] = lamda
                m.running_var[:] = 1.
                m.running_mean[:] = 0.
                m.track_running_stats = True
            # eval() stops running-stat updates, i.e. the actual "freeze"
            m.eval()
            if disable_grad:
                m.weight.requires_grad = False
                m.bias.requires_grad = False

def _calibrate_bn_pact(self, calib_dict={}, kappa_bit_default=14, lamda_bit_default=20, kappa_dict=None, lamda_dict=None, range_factor=8, minmax=False, **kwargs):
    r"""Calibrates BN layer quantization for :py:class:`nemo.quant.pact.PACT_QuantizedBatchNormNd` layers.
    Using BN min-max statistics previously calculated, this method calibrates the number of bits used in BN parameters
    so that both the BN output and the `kappa`, `lamda` affine parameters are representable.

    :param calib_dict: dictionary mapping each BN layer name to the activation layer used to estimate its output range; if empty (default), derived from the graph's supernodes.
    :type calib_dict: `dict` or `collections.OrderedDict`
    :param kappa_bit_default: Default maximum number of bits for BN multiplicative parameter (can be overridden by `kappa_dict`); default 14.
    :type kappa_bit_default: int
    :param lamda_bit_default: Default maximum number of bits for BN additive parameter (can be overridden by `lamda_dict`); default 20.
    :type lamda_bit_default: int
    :param kappa_dict: dictionary of maximum number of bits for `kappa` in specific layers; overrides `kappa_bit_default`. Default None.
    :type kappa_dict: `dict` or `collections.OrderedDict`
    :param lamda_dict: dictionary of maximum number of bits for `lamda` in specific layers; overrides `lamda_bit_default`. Default None.
    :type lamda_dict: `dict` or `collections.OrderedDict`
    :param range_factor: multiplier of the following activation's `alpha` used to approximate the BN output range when min-max statistics are unavailable; default 8.
    :type range_factor: int
    :param minmax: if True, use collected min/max statistics for the BN output range instead of the `alpha`-based approximation; default False.
    :type minmax: bool

    """

    if not calib_dict:
        calib_dict = get_calib_dict_from_supernodes(self)

    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Act" or \
            m.__class__.__name__ == "PACT_QuantizedBatchNormNd"):
            module_dict[n] = m

    for n,m in self.named_modules():
        if m.__class__.__name__ == "PACT_QuantizedBatchNormNd":
            kappa_max = m.kappa.abs().max()
            lamda_max = m.lamda.abs().max()
            if minmax:
                out_range = max(m.max.abs(), m.min.abs())
            else: # if we do not have minmax statistics, approximate with alpha mult by a range_factor
                out_range = module_dict[calib_dict[n]].alpha * range_factor
            if kappa_dict is not None:
                kappa_bit_lim = kappa_dict[n]
            else:
                kappa_bit_lim = kappa_bit_default
            if lamda_dict is not None:
                lamda_bit_lim = lamda_dict[n]
            else:
                lamda_bit_lim = lamda_bit_default
            # smallest quantum that still represents both lamda and the output
            # range within lamda_bit_lim (signed) bits
            eps_lim = max(out_range, lamda_max) / (2**(lamda_bit_lim - 1) - 1)
            eps_kappa_lim = eps_lim / m.eps_in
            # choose kappa bits so its quantum is no finer than eps_kappa_lim,
            # clamped to both the per-layer limit and the global cap
            kappa_bits = min(int(min(torch.log2(1 + 2*kappa_max / eps_kappa_lim).floor(), kappa_bit_lim)), BN_PRECISION_MAX)
            lamda_bits = min(lamda_bit_lim, BN_PRECISION_MAX)
            m.precision_kappa.set_bits(kappa_bits)
            m.precision_lamda.set_bits(lamda_bits)
            # print("%s %d %d" % (n, kappa_bits, lamda_bits))
def _unfreeze_bn(self):
    r"""Sets :py:class:`torch.nn.BatchNorm2d` layers to collect statistics and update `running_var` and `running_mean`.

    """

    for n,m in self.named_modules():
        if m.__class__.__name__ == "BatchNorm2d":
            m.train()

def _prune_empty_bn_pact(self, bn_dict={}, threshold=None):
    r"""Zeroes BN channels whose preceding linear-layer weights are all below a
    (quantization-derived or explicit) threshold, i.e. channels that quantize to zero.

    :param bn_dict: dictionary of layer names, key = linear layer, value = following batch-norm; derived from supernodes if empty.
    :type bn_dict: `dict` or `collections.OrderedDict`
    :param threshold: explicit pruning threshold; if None (default), use the weight quantum of the preceding layer.
    :type threshold: float

    """
    if not bn_dict:
        bn_dict = get_bn_dict_from_supernodes(self)

    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear" or \
            m.__class__.__name__ == "BatchNorm2d" or \
            m.__class__.__name__ == "BatchNorm1d" or \
            m.__class__.__name__ == "PACT_QuantizedBatchNormNd"):
            module_dict[n] = m
    for n_before in bn_dict.keys():
        n_after = bn_dict[n_before]
        m_before = module_dict[n_before]
        m_after = module_dict[n_after]
        if threshold is None:
            try:
                # weight quantum of the preceding layer: channels whose max |w|
                # is below one quantum are effectively zero after quantization
                eps = (m_before.W_alpha + m_before.W_beta) / (2.**m_before.W_precision.get_bits()-1)
            except AttributeError:
                continue
        else:
            eps = threshold
        # skip layers with a bias: the channel output is not zero even if weights are
        if m_before.bias is not None:
            continue
        try:
            # quantized BN: zero the multiplicative parameter of pruned channels
            m_after.kappa.data[m_before.weight.data.flatten(1).abs().max(1)[0] < eps] = 0.
        except AttributeError:
            # standard BN: zero weight and neutralize variance instead
            m_after.weight.data[m_before.weight.data.flatten(1).abs().max(1)[0] < eps] = 0.
            m_after.running_var.data[m_before.weight.data.flatten(1).abs().max(1)[0] < eps] = 1.

def _fold_bn_pact(self, bn_dict={}, bn_inv_dict={}, eps=None, phi_inv=0., reset_alpha=True, remove_bn=False):
    r"""Performs batch-normalization folding following the algorithm presented in
    https://arxiv.org/abs/1905.04166. It performs both normal folding and inverse
    folding using two separate dictionaries `bn_dict` and `bn_inv_dict`.

    :param bn_dict: a dictionary of layer names, with the key being the source (linear) and the value the target (batch-norm). If empty (default), uses the graph to fold all BN layers.
    :type bn_dict: `dict` or `collections.OrderedDict`
    :param bn_inv_dict: a dictionary of layers, with the key being the source (batch-norm) and the value the target (linear).
    :type bn_inv_dict: `dict` or `collections.OrderedDict`
    :param eps: if not None (the default), overrides numerical `eps` used within batch-norm layer.
    :type eps: float
    :param phi_inv: parameter added to `gamma` in inverse folding for better numerical stability (default 0).
    :type phi_inv: float
    :param reset_alpha: if True (default), reset the clipping parameters of weights.
    :type reset_alpha: bool
    :param remove_bn: if True, replace all BN layers with identity layers to "physically remove" BN from the network (default False).
    :type remove_bn: bool

    """

    if not bn_dict:
        bn_dict = get_bn_dict_from_supernodes(self)

    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear" or \
            m.__class__.__name__ == "BatchNorm2d" or \
            m.__class__.__name__ == "BatchNorm1d" or \
            m.__class__.__name__ == "PACT_QuantizedBatchNormNd"):
            module_dict[n] = m
    # first pass: snapshot gamma/beta/sigma/mu for every BN involved in folding
    param = {}
    bn_list = list(bn_dict.values()) + list(bn_inv_dict.keys())
    for n in bn_list:
        m = module_dict[n]
        # count how many time this occures as a value in the direct map (for bilinear functions e.g. add)
        count = list(bn_dict.values()).count(n)
        # NOTE(review): eps is only fetched while it is None, so the first BN's
        # eps is reused for all subsequent layers -- confirm this is intended
        if eps is None:
            try:
                eps = m.eps
            except AttributeError:
                eps = 0.
        if m.__class__.__name__ == "PACT_QuantizedBatchNormNd":
            # quantized BN already stores folded kappa/lamda; statistics are neutral
            gamma = m.kappa.data[:].clone().detach().cpu().flatten()
            beta = m.lamda.data[:].clone().detach().cpu().flatten()
            sigma = torch.ones_like(gamma).flatten()
            mu = torch.zeros_like(beta).flatten()
        else:
            gamma = m.weight.data[:].clone().detach().cpu()
            beta = m.bias.data[:].clone().detach().cpu()
            sigma = torch.sqrt(m.running_var.data[:] + eps).clone().detach().cpu()
            mu = m.running_mean.data[:].clone().detach().cpu()
        param[n] = {
            'gamma' : gamma,
            'beta'  : beta,
            'sigma' : sigma,
            'mu'    : mu,
            'count' : count
        }
    # direct folding (CONV->BN)
    for n in bn_dict.keys():
        n_bn = bn_dict[n]
        m = module_dict[n]
        m_bn = module_dict[n_bn]
        gamma = param[n_bn]['gamma']
        beta = param[n_bn]['beta']
        mu = param[n_bn]['mu']
        sigma = param[n_bn]['sigma']
        count = param[n_bn]['count']
        # when several layers feed the same BN (e.g. through an add), split the
        # additive terms evenly among them
        if count > 1:
            beta = beta/count
            mu = mu/count

        th_a = (gamma/sigma).to(m.weight.device)
        th_b = (beta-gamma/sigma * mu).to(m.weight.device)

        m.weight.data[:] = m.weight.data[:] * reshape_before(m, th_a)
        try:
            m.bias.data[:] = th_a * m.bias.data[:] + th_b
        except AttributeError:
            m.bias = torch.nn.Parameter(th_b)
    # inverse folding (BN->CONV)
    for n_bn in bn_inv_dict.keys():
        n = bn_inv_dict[n_bn]
        m = module_dict[n]
        m_bn = module_dict[n_bn]
        gamma = param[n_bn]['gamma']
        beta = param[n_bn]['beta']
        mu = param[n_bn]['mu']
        sigma = param[n_bn]['sigma']
        count = param[n_bn]['count']

        th_a = sigma/gamma
        # NOTE(review): shape[3:1] is an empty slice, so np.prod(...) is always 1.0;
        # a kernel-size product (e.g. shape[2:4]) looks like the intent -- verify
        shape_w = np.prod(np.asarray(m.weight.data.shape)[3:1])
        th_m_by_w = (reshape_after(m, mu)*m.weight.data[:]).sum(3).sum(2).sum(1) / shape_w
        th_bsg_by_w = (reshape_after(m, beta*sigma/(gamma+phi_inv))*m.weight.data[:]).sum(3).sum(2).sum(1) / shape_w

        if phi_inv is None:
            phi_inv = m_bn.eps
        m.weight.data[:] = m.weight.data[:] * reshape_after(m, th_a)
        try:
            m.bias.data[:] = m.bias.data[:] + th_m_by_w - th_bsg_by_w
        except AttributeError:
            m.bias = torch.nn.Parameter(th_m_by_w - th_bsg_by_w)
    # neutralize BatchNorm's
    for n in bn_list:
        m = module_dict[n]
        if m.__class__.__name__ == "PACT_QuantizedBatchNormNd":
            m.kappa.data[:] = 1.
            m.lamda.data[:] = 0.
        else:
            m.weight.data[:] = 1.
            m.bias.data[:] = 0.
            m.running_mean.data[:] = 0.
            m.running_var.data[:] = (1. - eps)**2
    if reset_alpha:
        self.reset_alpha_weights()

def _threshold_folding_pact(self, act_dict):
    r"""Performs the folding of batch-normalization layers into threshold-based
    activation layers.

    :param act_dict: a dictionary of layer names, with the key being the source (batch-norm) and the value the target (activation).
    :type act_dict: `dict` or `collections.OrderedDict`

    """

    module_dict = {}

    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_ThresholdAct" or \
            m.__class__.__name__ == "BatchNorm2d" or \
            m.__class__.__name__ == "BatchNorm1d" ):
            module_dict[n] = m
    for n_before in act_dict.keys():
        n_after = act_dict[n_before]
        m_before = module_dict[n_before]
        m_after = module_dict[n_after]
        # get BN parameters
        eps = m_before.eps
        gamma = m_before.weight.data[:]
        beta = m_before.bias.data[:]
        sigma = torch.sqrt(m_before.running_var.data[:] + eps)
        mu = m_before.running_mean.data[:]
        # setup threshold in PACT_ThresholdAct
        # re-create kappa/lamda as per-channel parameters on the activation's device
        del m_after.kappa, m_after.lamda
        m_after.kappa = torch.nn.Parameter(torch.zeros(gamma.shape[0]).to(m_after.alpha.data.device))
        m_after.lamda = torch.nn.Parameter(torch.zeros(gamma.shape[0]).to(m_after.alpha.data.device))
        m_after.kappa.data[:] = sigma/gamma
        m_after.lamda.data[:] = mu - beta*sigma/gamma
        # remove BN parameters
        m_before.weight.data[:] = 1.
        m_before.bias.data[:] = 0.
        m_before.running_var.data[:] = (1. - eps)**2
        m_before.running_mean.data[:] = 0.
import torch
from nemo.precision import Precision
from nemo.quant.pact import *
from nemo.graph import DeployGraph
from torch.nn.modules.utils import _single,_pair
from collections import OrderedDict
import types
import logging
import numpy as np
import copy
import math
import torchvision.models
import re

__all__ = [ "reshape_before", "reshape_after", "weight_range", "weight_min", "weight_max", "onnx_name_2_pytorch_name", "get_bn_dict_from_supernodes", "get_equalize_dict_from_supernodes", "get_calib_dict_from_supernodes" ]

def reshape_before(m, s):
    r"""Reshape a per-output-channel tensor `s` so that it broadcasts against
    the weight tensor of layer `m` (channels laid out on dim 0).
    Layers of unrecognized classes get `s` back unchanged.
    """
    if m.__class__.__name__ == "PACT_Conv2d":
        return s.reshape((s.shape[0],1,1,1))
    elif m.__class__.__name__ == "PACT_Conv1d":
        return s.reshape((s.shape[0],1,1))
    elif m.__class__.__name__ == "PACT_Linear":
        return s.reshape((s.shape[0],1))
    else:
        return s

def reshape_after(m, s):
    r"""Reshape a per-input-channel tensor `s` so that it broadcasts against
    the weight tensor of layer `m` (channels laid out on dim 1), except for
    depthwise convolutions, whose channels sit on dim 0.
    Layers of unrecognized classes get `s` back unchanged.
    """
    if m.__class__.__name__ == "PACT_Conv2d":
        if m.groups == s.shape[0]:
            # depthwise conv: "input" channels are laid out on dim 0
            return s.reshape((s.shape[0],1,1,1))
        else:
            return s.reshape((1,s.shape[0],1,1))
    elif m.__class__.__name__ == "PACT_Conv1d":
        return s.reshape((1,s.shape[0],1))
    elif m.__class__.__name__ == "PACT_Linear":
        return s.reshape((1,s.shape[0]))
    else:
        return s

def weight_max(m, range_idx):
    r"""Per-channel maximum of the weights of `m`, marginalizing all dims
    except `range_idx` (0 = output channels, 1 = input channels).
    BatchNorm layers return the absolute value of their scale parameter;
    unrecognized classes implicitly return None.
    """
    if m.__class__.__name__ == "PACT_Conv2d":
        if m.groups == m.weight.shape[0]:
            range_idx = 0 # for DW-conv, always marginalize idx 1
        return m.weight.max(3)[0].max(2)[0].max(1-range_idx)[0]
    elif m.__class__.__name__ == "PACT_Conv1d":
        return m.weight.max(2)[0].max(1-range_idx)[0]
    elif m.__class__.__name__ == "PACT_Linear":
        return m.weight.max(1-range_idx)[0]
    elif m.__class__.__name__ == "BatchNorm1d" or m.__class__.__name__ == "BatchNorm2d":
        return m.weight.data.abs()

def weight_min(m, range_idx):
    r"""Per-channel minimum of the weights of `m`; mirror of :func:`weight_max`.
    BatchNorm layers return the absolute value of their scale parameter;
    unrecognized classes implicitly return None.
    """
    if m.__class__.__name__ == "PACT_Conv2d":
        if m.groups == m.weight.shape[0]:
            range_idx = 0 # for DW-conv, always marginalize idx 1
        return m.weight.min(3)[0].min(2)[0].min(1-range_idx)[0]
    elif m.__class__.__name__ == "PACT_Conv1d":
        return m.weight.min(2)[0].min(1-range_idx)[0]
    elif m.__class__.__name__ == "PACT_Linear":
        return m.weight.min(1-range_idx)[0]
    elif m.__class__.__name__ == "BatchNorm1d" or m.__class__.__name__ == "BatchNorm2d":
        return m.weight.data.abs()

def weight_range(m, range_idx, symmetric=False):
    r"""Per-channel weight range of `m` along dim `range_idx`:
    ``max - min`` when `symmetric` is False, a symmetric range around zero
    otherwise. BatchNorm layers return the absolute value of their scale
    parameter; unrecognized classes implicitly return None.

    NOTE(review): the symmetric branch returns ``2*max|w|`` for PACT_Conv2d
    but only ``max|w|`` for PACT_Conv1d/PACT_Linear -- this looks
    inconsistent; confirm whether the factor 2 is intended only for Conv2d.
    """
    if m.__class__.__name__ == "PACT_Conv2d":
        if m.groups == m.weight.shape[0]:
            range_idx = 0 # for DW-conv, always marginalize idx 1
        if not symmetric:
            return m.weight.max(3)[0].max(2)[0].max(1-range_idx)[0] - m.weight.min(3)[0].min(2)[0].min(1-range_idx)[0]
        else:
            return 2*m.weight.abs().max(3)[0].max(2)[0].max(1-range_idx)[0]
    elif m.__class__.__name__ == "PACT_Conv1d":
        if not symmetric:
            return m.weight.max(2)[0].max(1-range_idx)[0] - m.weight.min(2)[0].min(1-range_idx)[0]
        else:
            return m.weight.abs().max(2)[0].max(1-range_idx)[0]
    elif m.__class__.__name__ == "PACT_Linear":
        if not symmetric:
            return m.weight.max(1-range_idx)[0] - m.weight.min(1-range_idx)[0]
        else:
            return m.weight.abs().max(1-range_idx)[0]
    elif m.__class__.__name__ == "BatchNorm1d" or m.__class__.__name__ == "BatchNorm2d":
        return m.weight.data.abs()

def onnx_name_2_pytorch_name(name):
    r"""Extract the dotted PyTorch module name from an ONNX scope name,
    e.g. ``'Net/Sequential[features]/Conv2d[0]'`` -> ``'features.0'``.
    """
    # raw string: '\[' is an invalid escape sequence in a plain string
    # literal (DeprecationWarning since Python 3.6, SyntaxWarning in 3.12+)
    name_parts = re.findall(r'\[.*?\]', name)
    name_parts = [part[1:-1] for part in name_parts]
    return '.'.join(name_parts)

def get_equalize_dict_from_supernodes(net):
    r"""Build, from the supernodes of `net.graph`, a pair of dictionaries:
    `eq_dict` mapping each Linear layer to the following Linear layer, and
    `act_dict` mapping each Linear layer to its supernode's activation key.
    """
    eq_dict = OrderedDict()
    act_dict = OrderedDict()
    # check all supernodes for ACT and CONV layers
    lin = {}
    act = {}
    prev = {}
    for k,ssn in net.graph.get_supernodes().items():
        for n in ssn['supernode']:
            if isinstance(n[1], PACT_Conv2d) or \
               isinstance(n[1], PACT_Conv1d) or \
               isinstance(n[1], PACT_Linear):
                lin[k] = n[0]
                prev[k] = ssn['previous']
                act[k] = k
    for k in prev.keys():
        p = lin.get(prev[k])
        if p is not None:
            # couple the Linear of the previous supernode with this one's
            eq_dict [lin[prev[k]]] = lin[k]
            act_dict[lin[prev[k]]] = act[prev[k]]
    return eq_dict, act_dict

def get_bn_dict_from_supernodes(net):
    r"""Build, from the supernodes of `net.graph`, a dictionary mapping each
    Linear layer name to the BatchNorm layer name within the same supernode.
    Returns None (after printing an error) if a supernode contains multiple
    BN or Linear layers.
    """
    bn_dict = OrderedDict()
    # check all supernodes for BN and CONV layers
    for k,ssn in net.graph.get_supernodes().items():
        bn = []
        lin = []
        sn = ssn['supernode']
        prev = ssn['previous']
        for n in sn:
            if isinstance(n[1], torch.nn.BatchNorm2d) or \
               isinstance(n[1], torch.nn.BatchNorm1d) or \
               isinstance(n[1], PACT_QuantizedBatchNormNd):
                bn.append(n[0])
            if isinstance(n[1], PACT_Conv2d) or \
               isinstance(n[1], PACT_Conv1d) or \
               isinstance(n[1], PACT_Linear):
                lin.append(n[0])
        if len(lin) > 1 or len(bn) > 1:
            print("[Error] Supernode analysis identified multiple BN or LIN layers in supernode")
            print(lin, bn)
            return
        try:
            bn_dict[lin[0]] = bn[0]
        except IndexError:
            # supernode without a (Linear, BN) couple: skip silently
            pass
    return bn_dict

def get_calib_dict_from_supernodes(net):
    r"""Build, from the supernodes of `net.graph`, a dictionary mapping each
    BatchNorm layer name to its supernode key (used for BN calibration).
    Returns None (after printing an error) if a supernode contains multiple
    BN layers.
    """
    calib_dict = OrderedDict()
    # check all supernodes for BN and ACT layers
    for k,ssn in net.graph.get_supernodes().items():
        bn = []
        sn = ssn['supernode']
        prev = ssn['previous']
        for n in sn:
            if isinstance(n[1], torch.nn.BatchNorm2d) or \
               isinstance(n[1], torch.nn.BatchNorm1d) or \
               isinstance(n[1], PACT_QuantizedBatchNormNd):
                bn.append(n[0])
        if len(bn) > 1:
            print("[Error] Supernode analysis identified multiple BN layers in supernode")
            print(bn)
            return
        try:
            calib_dict[bn[0]] = k
        except IndexError:
            # supernode without a BN layer: skip silently
            pass
    return calib_dict
#
# deploy.py
# Francesco Conti
#
# Copyright (C) 2018-2020 ETH Zurich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from nemo.precision import Precision
from nemo.quant.pact import *
from nemo.graph import DeployGraph
from torch.nn.modules.utils import _single,_pair
from collections import OrderedDict
import types
import logging
import numpy as np
import copy
import math
import torchvision.models
import re
from nemo.transf.common import *
import nemo

def _get_eps_at(self, *args, **kwargs):
    r"""Proxy to `DeployGraph.get_eps_at`; returns None when no graph is attached."""
    graph = getattr(self, 'graph', None)
    if graph is not None:
        return graph.get_eps_at(*args, **kwargs)

def _harden_weights_pact(self, **kwargs):
    r"""Harden all weights in the network to their quantized value.

    """

    for name, mod in self.named_modules():
        kind = type(mod).__name__
        if kind in ("PACT_Conv2d", "PACT_Conv1d", "PACT_Linear"):
            # remember the pre-hardening weight quantum for the training loop
            mod.train_loop_oldprec = float(mod.W_beta.item()+mod.W_alpha.item())/(2.0**(mod.W_precision.get_bits())-1)
            mod.harden_weights(**kwargs)
        elif kind == "PACT_QuantizedBatchNormNd":
            mod.harden_weights(**kwargs)

def _round_weights_pact(self, **kwargs):
    r"""Round all weights in the network adding 1/2 an eps.

    """

    for name, mod in self.named_modules():
        kind = type(mod).__name__
        if kind in ("PACT_Conv2d", "PACT_Conv1d", "PACT_Linear"):
            # half of the weight quantum
            mod.weight.data[:] += (mod.W_beta.item()+mod.W_alpha.item())/(2.0**(mod.W_precision.get_bits())-1) / 2
        elif kind == "PACT_QuantizedBatchNormNd":
            mod.kappa.data[:] += mod.eps_kappa/2
            mod.lamda.data[:] += mod.eps_lamda/2

def _set_deployment_pact(self, eps_in, only_activations=False, **kwargs):
    r"""Sets the network in deployment mode, enabling saving it to ONNX format or similar.

    :param eps_in: Input precision quantum.
    :type eps_in: float

    """

    self.stage = 'qd'
    if not only_activations:
        self.eps_in = eps_in
        self.set_eps_in(eps_in)
    for name, mod in self.named_modules():
        kind = type(mod).__name__
        # with only_activations, restrict deployment flagging to PACT_Act
        full_set = kind in ("PACT_Conv2d", "PACT_Conv1d", "PACT_Linear", "PACT_IntegerAdd")
        if (full_set and not only_activations) or kind == "PACT_Act":
            mod.deployment = True
        if kind == "PACT_Act":
            mod.set_static_precision(**kwargs)

def _set_eps_in_pact(self, eps_in):
    r"""Sets the input precision quantum of the network.

    :param eps_in: Input precision quantum.
    :type eps_in: float

    """

    assert(hasattr(self, 'graph'))
    assert(self.graph is not None)
    self.graph.rebuild_module_dict()
    for name, mod in self.named_modules():
        kind = type(mod).__name__
        if kind not in ("PACT_Act", "PACT_QuantizedBatchNormNd", "PACT_IntegerAdd"):
            continue
        eps_in_new = self.get_eps_at(name, eps_in)
        if eps_in_new is None:
            continue
        if kind in ("PACT_Act", "PACT_QuantizedBatchNormNd"):
            mod.eps_in = eps_in_new.clone().detach().requires_grad_(False)
        else:
            # PACT_IntegerAdd keeps one eps per input branch
            mod.eps_in_list = [torch.tensor(eps.item(), requires_grad=False) for eps in eps_in_new]

def _qd_stage(self, eps_in, add_input_bias_dict={}, remove_bias_dict={}, prune_empty_bn=True, int_accurate=True, bn_calibration_fn=None, bn_calibration_range_factor=8, **kwargs):
    r"""High-level function to move the network from FQ to QD stage.

    :param eps_in: Input precision quantum (required).
    :type eps_in: float
    :param add_input_bias_dict: dictionary of layers to which an input bias must be added (layer name as key, bias as value).
    :type add_input_bias_dict: dict or `collections.OrderedDict`
    :param remove_bias_dict: dictionary of Linear->BatchNorm couples where bias must be absorbed by the BatchNorm (Linear name as key, BatchNorm name as value).
    :type remove_bias_dict: dict or `collections.OrderedDict`
    :param prune_empty_bn: if True (default), BatchNorm channel multiplicative parameters are pruned if smaller than 1e-9.
    :type prune_empty_bn: bool
    :param int_accurate: if True (default), target an accurate representation of ID numerical dynamics (e.g., with requantization) at QD stage.
    :type int_accurate: bool
    :param bn_calibration_fn: if not None (default), a function (e.g., calling validation) used to calibrate BatchNorm range.
    :type bn_calibration_fn: function
    :param bn_calibration_range_factor: if bn_calibration_fn is None, multiply the clipping parameter of the following Activation multiplied by bn_calibration_range_factor to estimate BatchNorm range.
    :type bn_calibration_range_factor: int

    """

    if prune_empty_bn:
        self.prune_empty_bn(threshold=1e-9)
    self.round_weights()
    self.harden_weights()
    if add_input_bias_dict:
        self.add_input_bias(add_input_bias_dict)
    if remove_bias_dict:
        self.remove_bias(remove_bias_dict)
    if int_accurate:
        nemo.transform.bn_quantizer(self, **kwargs)
    else:
        # mainly useful for debug purposes, to identify FQ/QD misalignments
        for name, mod in self.named_modules():
            if type(mod).__name__ == "PACT_Act":
                mod.precise = True
    self.set_deployment(eps_in=eps_in, **kwargs) # with initial BN eps
    if bn_calibration_fn is not None:
        with self.statistics_bn():
            bn_calibration_fn()
        self.calibrate_bn(minmax=True, **kwargs)
    else:
        self.calibrate_bn(minmax=False, range_factor=bn_calibration_range_factor, **kwargs)
    self.set_deployment(eps_in=eps_in, **kwargs) # repeat, to fix BN eps
    self.harden_weights()

def _id_stage(self, eps_in=None, **kwargs):
    r"""High-level function to move the network from QD to ID stage.

    :param eps_in: Input precision quantum, default None (will use the previously saved eps_in).
    :type eps_in: float

    """

    if self.stage == 'fq':
        # pass through QD first if coming straight from FQ
        self.qd_stage(eps_in=eps_in, **kwargs)
    self.stage = 'id'
    if eps_in is None:
        eps_in = self.eps_in
    nemo.transform.integerize_pact(self, eps_in=eps_in, **kwargs)
# ---- /nemo/transf/equalize.py ----
#
# equalize.py
# Francesco Conti
#
# Copyright (C) 2018-2020 ETH Zurich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from nemo.precision import Precision
from nemo.quant.pact import *
from nemo.graph import DeployGraph
from torch.nn.modules.utils import _single,_pair
from collections import OrderedDict
import types
import logging
import numpy as np
import copy
import math
import torchvision.models
import re
from nemo.transf.common import *
from sklearn import linear_model

# Part of the procedure necessary for DFQ as described here https://arxiv.org/pdf/1906.04721.pdf
def _equalize_weights_dfq_pact(self, equalize_dict={}, act_dict={}, verbose=False, cost_eps=1e-3, max_iter=1000, reset_alpha=True):
    r"""This function implements the cross-layer weight-range equalization procedure proposed in
    the Data-Free Quantization paper by Qualcomm (https://arxiv.org/pdf/1906.04721.pdf).
    It should be used only after batch-normalization layers have been folded into convolution
    by means of the `fold_bn` or `fold_bn_withinv` methods.

    :param equalize_dict: a dictionary of layer names, with the key being a Linear and the value the next Linear layer.
    :type equalize_dict: `dict` or `collections.OrderedDict`
    :param act_dict: a dictionary of layer names, with the key being a Linear and the value the next Act layer. If empty, activation alpha scaling is not performed unless `equalize_dict` is also empty.
    :type act_dict: `dict` or `collections.OrderedDict`
    :param verbose: if True, prints more information.
    :type verbose: bool
    :param cost_eps: equalization will iterate until the cost is less than this threshold or the number of iterations is greater than `max_iter`.
    :type cost_eps: float
    :param max_iter: maximum number of iterations.
    :type max_iter: int
    :param reset_alpha: if True, reset the clipping parameters of weights (default True).
    :type reset_alpha: bool

    """

    if not equalize_dict:
        # derive the Linear->Linear and Linear->Act couples from graph supernodes
        equalize_dict, act_dict = get_equalize_dict_from_supernodes(self)

    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear" or \
            m.__class__.__name__ == "BatchNorm2d" or \
            m.__class__.__name__ == "BatchNorm1d" or \
            m.__class__.__name__ == "PACT_Act"):
            module_dict[n] = m
    it = 0
    cost = 1e10
    # fixed-point iteration: rescale each coupled pair until the precision
    # cost stabilizes below cost_eps or the iteration budget is exhausted
    while cost > cost_eps and it < max_iter:
        cost = 0.0
        for n_before in equalize_dict.keys():
            n_after = equalize_dict[n_before]
            m_before = module_dict[n_before]
            m_after = module_dict[n_after]
            # per-output-channel range of the first layer, per-input-channel of the second
            range_before = weight_range(m_before, 0)
            range_after = weight_range(m_after, 1)
            old_prec_before_mean = (weight_range(m_before, 0).abs() / (m_before.weight.max() - m_before.weight.min()).abs()).sum().item()
            old_prec_after_mean = (weight_range(m_after, 0).abs() / (m_after.weight.max() - m_after.weight.min()).abs()).sum().item()

            # this happens when the two layers are across a Flatten operation
            flatten_flag = False
            if range_after.shape[0] != range_before.shape[0]:
                flatten_flag = True
                range_after = range_after.reshape((range_before.shape[0], -1))
                flatten_dim = range_after.shape[1]
                range_after = range_after.max(1)[0]

            # DFQ scaling factor s_i = sqrt(r_after_i / r_before_i)
            s = torch.sqrt(range_after/range_before)
            m_before.weight.data[:] = m_before.weight.data[:] * reshape_before(m_before, s)
            if act_dict:
                # per-layer: has to use s max!
                module_dict[act_dict[n_before]].alpha.data[:] *= s.max()
            try:
                m_before.bias.data[:] = m_before.bias.data[:] * s
            except AttributeError:
                # layer has no bias
                pass

            if flatten_flag:
                # replicate s over the flattened spatial positions
                s = torch.cat(flatten_dim*(s.unsqueeze(1),),1).flatten()

            m_after.weight.data[:] = m_after.weight.data[:] / reshape_after(m_after, s)
            new_prec_before_mean = (weight_range(m_before, 0).abs() / (m_before.weight.max() - m_before.weight.min()).abs()).sum().item()
            new_prec_after_mean = (weight_range(m_after, 0).abs() / (m_after.weight.max() - m_after.weight.min()).abs()).sum().item()
            cost += np.abs(new_prec_before_mean*new_prec_after_mean - old_prec_before_mean*old_prec_after_mean)
        it += 1
        if verbose:
            logging.info("[DFQ Equalization] cost=%.5f" % cost)
    logging.info("[DFQ Equalization] terminated after %d iterations" % it)
    if reset_alpha:
        self.reset_alpha_weights()

def _equalize_weights_unfolding_pact(self, bn_dict={}, verbose=False, eps=None):
    r"""Performs in-layer equalization by unfolding of convolution parameters
    into batch-normalization parameters.

    :param bn_dict: a dictionary of layer names, with the key being the source (linear) and the value the target (batch-norm).
    :type bn_dict: `dict` or `collections.OrderedDict`
    :param verbose: if True, prints more information.
    :type verbose: bool
    :param eps: if not None (the default), overrides numerical `eps` used within batch-norm layer.
    :type eps: float

    """

    if not bn_dict:
        bn_dict = get_bn_dict_from_supernodes(self)

    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear" or \
            m.__class__.__name__ == "BatchNorm2d" or \
            m.__class__.__name__ == "BatchNorm1d" ):
            module_dict[n] = m
    for n_before in bn_dict.keys():
        n_after = bn_dict[n_before]
        m_before = module_dict[n_before]
        m_after = module_dict[n_after]
        if eps is None:
            eps = m_after.eps
        range_before = weight_range(m_before, 0)
        if verbose:
            logging.info("[Equalization by Unfolding] %s: wrange_min=%.5f wrange_max=%.5f" % (n_before, range_before.min().item(), range_before.max().item()))
        # divide the linear layer (weights + bias) by its per-channel range
        # and compensate inside the following BN (running_mean and weight)
        m_before.weight.data[:] = m_before.weight.data[:] / reshape_before(m_before, range_before)
        try:
            m_before.bias.data[:] = m_before.bias.data[:] / range_before
        except AttributeError:
            # layer has no bias
            pass
        m_after.running_mean.data[:] = m_after.running_mean.data[:] / range_before
        m_after.weight.data[:] = m_after.weight.data[:] * reshape_after(m_after, range_before)
        if verbose:
            logging.info("[Equalization by Unfolding] %s: wrange_min=%.5f wrange_max=%.5f" % (n_before, weight_range(m_before, 0).min().item(), weight_range(m_before, 0).max().item()))

def _equalize_weights_lsq_pact(self, bn_dict={}, verbose=False, eps=None):
    r"""Performs in-layer equalization by unfolding of convolution parameters
    into batch-normalization parameters.

    :param bn_dict: a dictionary of layer names, with the key being the source (linear) and the value the target (batch-norm).
    :type bn_dict: `dict` or `collections.OrderedDict`
    :param verbose: if True, prints more information.
    :type verbose: bool
    :param eps: if not None (the default), overrides numerical `eps` used within batch-norm layer.
    :type eps: float

    """

    if not bn_dict:
        bn_dict = get_bn_dict_from_supernodes(self)

    module_dict = {}
    for n,m in self.named_modules():
        if (m.__class__.__name__ == "PACT_Conv2d" or \
            m.__class__.__name__ == "PACT_Conv1d" or \
            m.__class__.__name__ == "PACT_Linear" or \
            m.__class__.__name__ == "BatchNorm2d" or \
            m.__class__.__name__ == "BatchNorm1d" ):
            module_dict[n] = m
    for n_before in bn_dict.keys():
        n_after = bn_dict[n_before]
        m_before = module_dict[n_before]
        m_after = module_dict[n_after]
        if eps is None:
            eps = m_after.eps
        min_before = weight_min(m_before, 0).cpu().detach().numpy()
        max_before = weight_max(m_before, 0).cpu().detach().numpy()
        if verbose:
            logging.info("[Equalization by Least Squares] %s: wrange_min=%.5f wrange_max=%.5f" % (n_before, weight_range(m_before, 0).min().item(), weight_range(m_before, 0).max().item()))
        # per channel, fit a zero-intercept line through (min,-1),(max,+1):
        # the slope maps the channel's weight span onto [-1,1]
        X = np.vstack((min_before, max_before))
        y = np.asarray((-1,1))
        coeff = torch.zeros(len(min_before), device=m_before.weight.device)
        regr = linear_model.LinearRegression(fit_intercept=False)
        for i in range(len(min_before)):
            regr.fit(X[:,i].reshape((-1,1)), y)
            coeff[i] = torch.as_tensor(regr.coef_[0], device=m_before.weight.device)
        # invert: coeff now scales weights to unit range
        # NOTE(review): a zero slope (degenerate channel) would yield inf here
        coeff = 1./coeff
        m_before.weight.data[:] = m_before.weight.data[:] / reshape_before(m_before, coeff)
        try:
            m_before.bias.data[:] = m_before.bias.data[:] / coeff
        except AttributeError:
            # layer has no bias
            pass
        m_after.running_mean.data[:] = m_after.running_mean.data[:] / coeff
        m_after.weight.data[:] = m_after.weight.data[:] * reshape_after(m_after, coeff)
        if verbose:
            logging.info("[Equalization by Least Squares] %s: wrange_min=%.5f wrange_max=%.5f" % (n_before, weight_range(m_before, 0).min().item(), weight_range(m_before, 0).max().item()))

# ---- /nemo/transf/export.py ----
#
# export.py
# Francesco Conti
#
# Copyright (C) 2018-2020 ETH Zurich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from nemo.precision import Precision
from nemo.quant.pact import *
from nemo.graph import DeployGraph
from torch.nn.modules.utils import _single,_pair
from collections import OrderedDict
import types
import logging
import numpy as np
import copy
import math
import torchvision.models
import re
from nemo.transf.common import *

def _export_precision(self):
    r"""Returns a dictionary of precisions for each layer.

    """

    d = OrderedDict([])
    for n,m in self.named_modules():
        d[n] = OrderedDict([])
        try:
            d[n]['x_bits'] = m.precision.get_bits()
            if not "PACT" in m.__class__.__name__:
                d[n]['x_scale'] = m.precision.get_scale()
        except AttributeError:
            # module has no activation precision attribute
            pass
        try:
            d[n]['W_bits'] = m.W_precision.get_bits()
            if not "PACT" in m.__class__.__name__:
                d[n]['W_scale'] = m.W_precision.get_scale()
        except AttributeError:
            # module has no weight precision attribute
            pass
        # drop modules that contributed nothing, and the root module ""
        if len(d[n].keys()) == 0 or n == "":
            d.pop(n, None)
    return d
36 | 37 | """ 38 | 39 | d = OrderedDict([]) 40 | for n,m in self.named_modules(): 41 | d[n] = OrderedDict([]) 42 | try: 43 | d[n]['x_bits'] = m.precision.get_bits() 44 | if not "PACT" in m.__class__.__name__: 45 | d[n]['x_scale'] = m.precision.get_scale() 46 | except AttributeError: 47 | pass 48 | try: 49 | d[n]['W_bits'] = m.W_precision.get_bits() 50 | if not "PACT" in m.__class__.__name__: 51 | d[n]['W_scale'] = m.W_precision.get_scale() 52 | except AttributeError: 53 | pass 54 | if len(d[n].keys()) == 0 or n == "": 55 | d.pop(n, None) 56 | return d 57 | 58 | def _export_weights_legacy_int16(self, header_name='weights.h', save_binary=False, folder_name='.', x_alpha_safety_factor=1.): 59 | r"""Exports weights and bias values with the legacy strategies used e.g. in PULP-DroNet, 60 | towards INT-16. Quantization is fully symmetric and aligned to power-of-two `alpha` so that 61 | there is no need to propagate :math:`\varepsilon` values. 62 | 63 | :param header_name: name of a header file. 64 | :type header_name: string 65 | :param save_binary: if True, saves also a binary version. 66 | :type save_binary: bool 67 | :param folder_name: name of the folder where to save binaries. 
68 | :type folder_name: string 69 | 70 | """ 71 | 72 | weight_dict = {} 73 | bias_dict = {} 74 | qi_dict = {} # actually qi-1, as 1 bit is for sign 75 | x_alpha = 0.001 76 | W_alpha = {} 77 | bigstr = "/* weights & biases */\n\n#include \n\n\n" 78 | checkstr_w = "" 79 | checkstr_b = "" 80 | for n,m in self.named_modules(): 81 | if (m.__class__.__name__ == "PACT_Act"): 82 | x_alpha = max(x_alpha, m.alpha.item()) 83 | for n,m in self.named_modules(): 84 | if (m.__class__.__name__ == "PACT_Act"): 85 | m.alpha.data[:] = 2.**int(math.ceil(math.log2(x_alpha))) 86 | x_alpha *= x_alpha_safety_factor 87 | for n,m in self.named_modules(): 88 | if (m.__class__.__name__ == "PACT_Conv2d" or \ 89 | m.__class__.__name__ == "PACT_Conv1d" or \ 90 | m.__class__.__name__ == "PACT_Linear"): 91 | W_alpha[n] = max(-m.W_alpha.item(), m.W_beta.item()) 92 | qi_dict[n] = int(math.ceil(math.log2(W_alpha[n]))) 93 | W_eps = 2.**-(16-qi_dict[n]-1) 94 | m.W_beta.data[:] = 2.**qi_dict[n] 95 | m.W_alpha.data[:] = 2.**qi_dict[n] 96 | # performs also weight hardening, destructive! 
97 | m.harden_weights() 98 | weight_dict[n] = np.int16(((m.weight.data.clone().detach().to('cpu').numpy()) / W_eps)) 99 | m.weight.data[:] = torch.tensor(weight_dict[n] * W_eps) 100 | x_eps = 2.**-(16-int(math.ceil(math.log2(x_alpha)))-1) 101 | try: 102 | bias_dict[n] = np.int16(m.bias.data.clone().detach().to('cpu').numpy() / x_eps) 103 | m.bias.data[:] = torch.tensor(bias_dict[n] * x_eps) 104 | except AttributeError: 105 | bias_dict[n] = np.int16(np.zeros(weight_dict[n].shape[0])) 106 | import re 107 | n_str = re.sub('[^0-9a-zA-Z_]+', '_', n) 108 | bigstr += "// %s weights [shape=%s, qi=%d, qf=%d]\n" % (n, weight_dict[n].shape, qi_dict[n]+1, 16-qi_dict[n]-1) 109 | bigstr += "int16_t w_%s[] = {\n " % n_str 110 | for i in range(len(weight_dict[n].flatten())-1): 111 | bigstr += "0x%04x,\n " % np.uint16(weight_dict[n].flatten()[i]) 112 | bigstr += "0x%04x\n};\n\n" % np.uint16(weight_dict[n].flatten()[-1]) 113 | bigstr += "// %s bias [shape=%s, qi=%d, qf=%d]\n" % (n, bias_dict[n].shape, int(math.ceil(math.log2(x_alpha)))+1, 16-int(math.ceil(math.log2(x_alpha)))-1) 114 | bigstr += "int16_t b_%s[] = {\n " % n_str 115 | for i in range(len(bias_dict[n].flatten())-1): 116 | bigstr += "0x%04x,\n " % np.uint16(bias_dict[n].flatten()[i]) 117 | bigstr += "0x%04x\n};\n\n\n" % np.uint16(bias_dict[n].flatten()[-1]) 118 | if save_binary: 119 | with open("%s/weights_%s.hex" % (folder_name, n_str), "w") as file: 120 | weight_dict[n].flatten().tofile(file) 121 | with open("%s/bias_%s.hex" % (folder_name, n_str), "w") as file: 122 | bias_dict[n].flatten().tofile(file) 123 | checkstr_w += "Checksum weights_%s:\t%s\n" % (n_str, sum(weight_dict[n].flatten())) 124 | checkstr_b += "Checksum bias_%s:\t%s\n" % (n_str, sum(bias_dict[n].flatten())) 125 | print("Export procedure completed, qi=%d qf=%d for activations" % (int(math.ceil(math.log2(x_alpha)))+1, 16-int(math.ceil(math.log2(x_alpha)))-1)) 126 | with open("%s/%s" % (folder_name, header_name), "w") as file: 127 | file.write(bigstr) 128 
| with open("%s/checksum.txt" % (folder_name), "w") as file: 129 | file.write(checkstr_w) 130 | file.write(checkstr_b) -------------------------------------------------------------------------------- /nemo/transf/pruning.py: -------------------------------------------------------------------------------- 1 | # 2 | # pruning.py 3 | # Francesco Conti 4 | # 5 | # Copyright (C) 2018-2020 ETH Zurich 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import torch 20 | from nemo.precision import Precision 21 | from nemo.quant.pact import * 22 | from nemo.graph import DeployGraph 23 | from torch.nn.modules.utils import _single,_pair 24 | from collections import OrderedDict 25 | import types 26 | import logging 27 | import numpy as np 28 | import copy 29 | import math 30 | import torchvision.models 31 | import re 32 | from nemo.transf.common import * 33 | 34 | def _prune_weights_pact(self, **kwargs): 35 | r"""Prune weights. 
36 | 37 | """ 38 | 39 | pruned = 0 40 | for n,m in self.named_modules(): 41 | if (m.__class__.__name__ == "PACT_Conv2d" or \ 42 | m.__class__.__name__ == "PACT_Conv1d" or \ 43 | m.__class__.__name__ == "PACT_Linear"): 44 | pruned += m.prune_weights(**kwargs) 45 | return pruned 46 | -------------------------------------------------------------------------------- /nemo/transf/sawb.py: -------------------------------------------------------------------------------- 1 | # 2 | # sawb.py 3 | # Francesco Conti 4 | # 5 | # Copyright (C) 2018-2021 ETH Zurich and University of Bologna 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
18 | 19 | import torch 20 | from nemo.precision import Precision 21 | from nemo.quant.pact import * 22 | from nemo.graph import DeployGraph 23 | from torch.nn.modules.utils import _single,_pair 24 | from collections import OrderedDict 25 | import types 26 | import logging 27 | import numpy as np 28 | import copy 29 | import math 30 | import torchvision.models 31 | import re 32 | from nemo.transf.common import * 33 | 34 | # first entry is c1, second is c2 with alpha_w^* = c1 * sqrt(Ew2) - c2 * Ew1 35 | __sawb_asymm_lut = { 36 | 2: [8.356, 7.841], 37 | 3: [4.643, 3.729], 38 | 4: [8.356, 7.841], 39 | 5: [12.522, 12.592], 40 | 6: [15.344, 15.914], 41 | 7: [19.767, 21.306], 42 | 8: [26.294, 29.421] 43 | } 44 | 45 | # Disable gradients for alpha,beta params 46 | def _disable_grad_sawb(self, layer_bits={}): 47 | 48 | # Colab with SAWB LUT: https://colab.research.google.com/drive/1UEQnvVcSP3N-QTZLEumbbGCv_oLv-JtL 49 | module_dict = {} 50 | use_default = False 51 | if not layer_bits: 52 | layer_bits = {} 53 | use_default = True 54 | for n,m in self.named_modules(): 55 | if (m.__class__.__name__ == "PACT_Conv2d" or \ 56 | m.__class__.__name__ == "PACT_Conv1d" or \ 57 | m.__class__.__name__ == "PACT_Linear"): 58 | m.W_alpha.requires_grad = False 59 | m.W_beta.requires_grad = False 60 | 61 | # Set weight clipping parameters according to Statistics-Aware Weight Binning 62 | def _weight_clip_sawb(self, asymmetric=True, layer_bits={}, check_minmax=True, verbose=False): 63 | 64 | # Colab with SAWB LUT: https://colab.research.google.com/drive/1UEQnvVcSP3N-QTZLEumbbGCv_oLv-JtL 65 | module_dict = {} 66 | use_default = False 67 | if not layer_bits: 68 | layer_bits = {} 69 | use_default = True 70 | for n,m in self.named_modules(): 71 | if (m.__class__.__name__ == "PACT_Conv2d" or \ 72 | m.__class__.__name__ == "PACT_Conv1d" or \ 73 | m.__class__.__name__ == "PACT_Linear"): 74 | if use_default: 75 | module_dict[n] = m 76 | layer_bits[n] = m.W_precision.get_bits() 77 | elif n in 
layer_bits.keys(): 78 | module_dict[n] = m 79 | 80 | for n in module_dict.keys(): 81 | m = module_dict[n] 82 | # compute E[|w|] 83 | Ew1 = m.weight.abs().mean() 84 | # compute E[w^2] 85 | Ew2 = (m.weight.abs() ** 2).mean() 86 | # compute alpha 87 | alpha = __sawb_asymm_lut[layer_bits[n]][0] * torch.sqrt(Ew2) - __sawb_asymm_lut[layer_bits[n]][1] * Ew1 88 | # compute beta 89 | eps = 2*alpha / (2**layer_bits[n]) 90 | if asymmetric: 91 | beta = alpha + eps * (2**layer_bits[n]-1) 92 | else: 93 | beta = alpha + eps * 2**layer_bits[n] 94 | if check_minmax: 95 | m.W_alpha.data[:] = min(alpha, m.weight.min().abs()) 96 | m.W_beta.data[:] = min(beta, m.weight.max().abs()) 97 | else: 98 | m.W_alpha.data[:] = alpha 99 | m.W_beta.data[:] = beta 100 | if verbose: 101 | print("[weight clip SAWB] %s: Ew1=%.3e Ew2=%.3e alpha=%.3e beta=%.3e" % (n, Ew1, Ew2, m.W_alpha.data.item(), m.W_beta.data.item())) 102 | 103 | -------------------------------------------------------------------------------- /nemo/transf/statistics.py: -------------------------------------------------------------------------------- 1 | # 2 | # statistics.py 3 | # Francesco Conti 4 | # 5 | # Copyright (C) 2018-2020 ETH Zurich 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 
import torch
from nemo.precision import Precision
from nemo.quant.pact import *
from nemo.graph import DeployGraph
from torch.nn.modules.utils import _single,_pair
from collections import OrderedDict
import types
import logging
import numpy as np
import copy
import math
import torchvision.models
import re
from nemo.transf.common import *
from contextlib import contextmanager

@contextmanager
def _statistics_act_pact(self):
    r"""Used with `with net.statistics_act():`, calls `net.set_statistics_act()` on enter
    and `net.unset_statistics_act()` on exit.

    """
    self.set_statistics_act()
    try:
        yield
    finally:
        self.unset_statistics_act()

@contextmanager
def _statistics_bn_pact(self):
    r"""Used with `with net.statistics_bn():`, calls `net.set_statistics_bn()` on enter
    and `net.unset_statistics_bn()` on exit.

    """
    self.set_statistics_bn()
    try:
        yield
    finally:
        self.unset_statistics_bn()

def _set_statistics_act_pact(self):
    r"""Sets :py:class:`nemo.quant.PACT_Act` layers to collect statistics and work like ReLU's.

    """
    for n, m in self.named_modules():
        if m.__class__.__name__ == "PACT_Act":
            m.statistics_only = True

def _get_statistics_act_pact(self):
    r"""Returns the statistics collected by :py:class:`nemo.quant.PACT_Act` layers.

    :return: per-layer `OrderedDict` with keys `max`, `running_mean`, `running_var`,
        `active`; only :py:class:`nemo.quant.PACT_Act` layers get an entry
        (previously every module received an empty placeholder entry).
    :rtype: `collections.OrderedDict`
    """
    d = OrderedDict([])
    for n, m in self.named_modules():
        if m.__class__.__name__ == "PACT_Act":
            stats = m.get_statistics()  # hoisted: was called three times per layer
            d[n] = OrderedDict([])
            d[n]['max'] = stats[0]
            d[n]['running_mean'] = stats[1]
            d[n]['running_var'] = stats[2]
            d[n]['active'] = m.statistics_only
    return d

def _unset_statistics_act_pact(self):
    r"""Sets :py:class:`nemo.quant.PACT_Act` layers to act normally and stop statistics collection.

    """
    for n, m in self.named_modules():
        if m.__class__.__name__ == "PACT_Act":
            m.statistics_only = False

def _set_statistics_bn_pact(self):
    r"""Sets :py:class:`nemo.quant.PACT_QuantizedBatchNormNd` layers to collect statistics.

    """
    for n, m in self.named_modules():
        if m.__class__.__name__ == "PACT_QuantizedBatchNormNd":
            m.statistics_only = True

def _unset_statistics_bn_pact(self):
    r"""Sets :py:class:`nemo.quant.PACT_QuantizedBatchNormNd` layers to act normally and stop statistics collection.

    """
    for n, m in self.named_modules():
        if m.__class__.__name__ == "PACT_QuantizedBatchNormNd":
            m.statistics_only = False
import torch
import nemo
from nemo.precision import Precision
from nemo.quant.pact import *
from nemo.graph import DeployGraph
from torch.nn.modules.utils import _single,_pair
from collections import OrderedDict
import types
import logging
import numpy as np
import copy
import math
import torchvision.models
import re
from nemo.transf.common import *

# Class names of the PACT weight-carrying (linear) layers.
_PACT_LINEAR_NAMES = ("PACT_Conv2d", "PACT_Conv1d", "PACT_Linear")

def _change_precision_pact(self, bits=4, scale_activations=True, scale_weights=True, verbose=True, reset_alpha=True, min_prec_dict=None, **kwargs):
    r"""Changes the target precision of a PACT quantization-aware layer.

    :param bits: target bit-width.
    :type bits: `int`
    :param scale_activations: if False, do not change precision of activations (default True).
    :type scale_activations: boolean
    :param scale_weights: if False, do not change precision of weights (default True).
    :type scale_weights: boolean
    :param verbose: if False, do not log precision information (default True).
    :type verbose: boolean
    :param reset_alpha: if False, do not reset weight scale parameter upon precision change (default True).
    :type reset_alpha: boolean
    :param min_prec_dict: dictionary of minimum layer-by-layer precisions (default None).
    :type min_prec_dict: dictionary
    """
    if bits is not None:
        if scale_activations:
            self.x_precision.bits = bits
        if scale_weights:
            self.W_precision.bits = bits
    for name, mod in self.named_modules():
        # per-layer precision floor, defaulting to the network-wide target
        x_floor = copy.deepcopy(self.x_precision)
        W_floor = copy.deepcopy(self.W_precision)
        if min_prec_dict is not None:
            layer_spec = min_prec_dict.get(name, {})
            if 'x_bits' in layer_spec:
                x_floor.bits = layer_spec['x_bits']
            if 'W_bits' in layer_spec:
                W_floor.bits = layer_spec['W_bits']
        is_act = mod.__class__.__name__ == "PACT_Act"
        is_linear = mod.__class__.__name__ in _PACT_LINEAR_NAMES
        if is_act and scale_activations:
            mod.precision = max(self.x_precision, x_floor)
        if is_linear and scale_weights:
            mod.W_precision = max(self.W_precision, W_floor)
            if reset_alpha:
                mod.reset_alpha_weights()
        if verbose and is_act and scale_activations:
            try:
                logging.info("[Quant]\t\t %s: x_bits=%.2f" % (name, mod.precision.get_bits()))
            except AttributeError:
                pass
        if verbose and is_linear and scale_weights:
            try:
                logging.info("[Quant]\t\t %s: W_bits=%.2f" % (name, mod.W_precision.get_bits()))
            except AttributeError:
                pass

def _set_train_loop_pact(self):
    r"""Sets modules so that weights are not treated like hardened (e.g., for training).

    """
    for _, mod in self.named_modules():
        if mod.__class__.__name__ in _PACT_LINEAR_NAMES:
            mod.train_loop = True

def _unset_train_loop_pact(self):
    r"""Sets modules so that weights are treated like hardened (e.g., for evaluation).

    """
    for _, mod in self.named_modules():
        if mod.__class__.__name__ in _PACT_LINEAR_NAMES:
            mod.train_loop = False
            # remember the quantum of the current clipping range for later use
            clip_span = float(mod.W_beta.item() + mod.W_alpha.item())
            mod.train_loop_oldprec = clip_span / (2.0 ** mod.W_precision.get_bits() - 1)

def _reset_alpha_act_pact(self, **kwargs):
    r"""Resets :py:class:`nemo.quant.PACT_Act` parameter `alpha` to the value collected through statistics.

    """
    for _, mod in self.named_modules():
        if mod.__class__.__name__ == "PACT_Act":
            mod.reset_alpha(**kwargs)

def _get_nonclip_parameters_pact(self):
    r"""Yields all parameters except for `alpha` values.

    """
    for name, param in self.named_parameters(recurse=True):
        if not name.endswith('alpha'):
            yield param

def _get_clip_parameters_pact(self):
    r"""Yields all `alpha` parameters.

    """
    for name, param in self.named_parameters(recurse=True):
        if name.endswith('alpha'):
            yield param

def _reset_alpha_weights_pact(self, method='standard', **kwargs):
    r"""Resets parameter `W_alpha` on all PACT linear layers.

    """
    # NOTE(review): `method` is accepted but not forwarded to reset_alpha_weights —
    # confirm whether it should be passed through.
    for _, mod in self.named_modules():
        if mod.__class__.__name__ in _PACT_LINEAR_NAMES:
            mod.reset_alpha_weights(**kwargs)
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from nemo.precision import Precision
from nemo.quant.pact import *
from nemo.graph import DeployGraph
from torch.nn.modules.utils import _single,_pair
from collections import OrderedDict
import types
import logging
import numpy as np
import copy
import math
import torchvision.models
import re
import nemo
from nemo.transf.common import *
from nemo.transf.bias import *
from nemo.transf.bn import *
from nemo.transf.deploy import *
from nemo.transf.equalize import *
from nemo.transf.export import *
from nemo.transf.pruning import *
from nemo.transf.statistics import *
from nemo.transf.utils import *
from nemo.transf.sawb import *

def quantize_pact(module, W_bits=4, x_bits=4, dummy_input=None, remove_dropout=False, **kwargs):
    r"""Takes a PyTorch module and makes it quantization-aware with PACT, recursively.

    Linear layers (:py:class:`torch.nn.Conv2d`, :py:class:`torch.nn.Conv1d`,
    :py:class:`torch.nn.Linear`) are replaced by their quantization-aware PACT
    counterparts (weights quantized, input activations left alone by default);
    activation layers (:py:class:`torch.nn.ReLU`, :py:class:`torch.nn.ReLU6`,
    :py:class:`torch.nn.LeakyReLU`) are replaced by
    :py:class:`nemo.quant.pact.PACT_Act`, which clips and quantizes.

    The returned module is also augmented with the NEMO transformation API as
    bound methods (`change_precision`, `harden_weights`, `set_statistics_act`,
    `reset_alpha_act`, `prune_weights`, ...).

    :param module: module to be transformed (typically a container such as :py:class:`torch.nn.ModuleList`).
    :type module: `torch.nn.Module`
    :param W_bits: target precision for weights.
    :type W_bits: float
    :param x_bits: target precision for activations.
    :type x_bits: float
    :param dummy_input: dummy input tensor used to derive an adjacency map by tracing (default None).
    :type dummy_input: `torch.Tensor`
    :param remove_dropout: if True, removes dropout layers before graph construction.
    :type remove_dropout: bool
    :return: The quantization-aware module.
    :rtype: same as `module`
    """
    module.eval()
    if remove_dropout:
        module = nemo.transform.dropout_to_identity(module)
    # with a dummy input available, trace an adjacency map of the module
    module.graph = DeployGraph(module, dummy_input=dummy_input) if dummy_input is not None else None
    module.stage = 'fq'
    module = _hier_quantizer_pact(module, module.graph, **kwargs)
    if getattr(module, 'graph', None) is not None:
        module.graph.rebuild_module_dict()
    # bind the NEMO transformation API onto the container as bound methods
    method_table = (
        ('add_input_bias',              nemo.transf.bias._add_input_bias_pact),
        ('remove_bias',                 nemo.transf.bias._remove_bias_pact),
        ('remove_input_bias',           nemo.transf.bias._remove_input_bias_pact),
        ('freeze_bn',                   nemo.transf.bn._freeze_bn),
        ('unfreeze_bn',                 nemo.transf.bn._unfreeze_bn),
        ('absorb_affine_bn',            nemo.transf.bn._absorb_affine_bn),
        ('prime_affine_bn',             nemo.transf.bn._prime_affine_bn),
        ('fold_bn',                     nemo.transf.bn._fold_bn_pact),
        # NOTE(review): fold_bn_withinv binds the same function as fold_bn in the
        # original code — confirm a distinct _fold_bn_withinv_pact was not intended.
        ('fold_bn_withinv',             nemo.transf.bn._fold_bn_pact),
        ('fold_thresholds',             nemo.transf.bn._threshold_folding_pact),
        ('prune_empty_bn',              nemo.transf.bn._prune_empty_bn_pact),
        ('calibrate_bn',                nemo.transf.bn._calibrate_bn_pact),
        ('get_eps_at',                  nemo.transf.deploy._get_eps_at),
        ('set_eps_in',                  nemo.transf.deploy._set_eps_in_pact),
        ('round_weights',               nemo.transf.deploy._round_weights_pact),
        ('harden_weights',              nemo.transf.deploy._harden_weights_pact),
        ('set_deployment',              nemo.transf.deploy._set_deployment_pact),
        ('qd_stage',                    nemo.transf.deploy._qd_stage),
        ('id_stage',                    nemo.transf.deploy._id_stage),
        ('export_precision',            nemo.transf.export._export_precision),
        ('export_weights_legacy_int16', nemo.transf.export._export_weights_legacy_int16),
        ('change_precision',            nemo.transf.utils._change_precision_pact),
        ('reset_alpha_weights',         nemo.transf.utils._reset_alpha_weights_pact),
        ('reset_alpha_act',             nemo.transf.utils._reset_alpha_act_pact),
        ('get_clip_parameters',         nemo.transf.utils._get_clip_parameters_pact),
        ('get_nonclip_parameters',      nemo.transf.utils._get_nonclip_parameters_pact),
        ('set_train_loop',              nemo.transf.utils._set_train_loop_pact),
        ('unset_train_loop',            nemo.transf.utils._unset_train_loop_pact),
        ('prune_weights',               nemo.transf.pruning._prune_weights_pact),
        ('equalize_weights_dfq',        nemo.transf.equalize._equalize_weights_dfq_pact),
        ('equalize_weights_lsq',        nemo.transf.equalize._equalize_weights_lsq_pact),
        ('equalize_weights_unfolding',  nemo.transf.equalize._equalize_weights_unfolding_pact),
        ('statistics_act',              nemo.transf.statistics._statistics_act_pact),
        ('set_statistics_act',          nemo.transf.statistics._set_statistics_act_pact),
        ('get_statistics_act',          nemo.transf.statistics._get_statistics_act_pact),
        ('unset_statistics_act',        nemo.transf.statistics._unset_statistics_act_pact),
        ('statistics_bn',               nemo.transf.statistics._statistics_bn_pact),
        ('set_statistics_bn',           nemo.transf.statistics._set_statistics_bn_pact),
        ('unset_statistics_bn',         nemo.transf.statistics._unset_statistics_bn_pact),
        ('disable_grad_sawb',           nemo.transf.sawb._disable_grad_sawb),
        ('weight_clip_sawb',            nemo.transf.sawb._weight_clip_sawb),
    )
    for attr_name, fn in method_table:
        setattr(module, attr_name, types.MethodType(fn, module))
    module.W_precision = Precision(W_bits, None)
    module.x_precision = Precision(x_bits, None)
    return module

def _hier_replacer(module, name, replacement):
    """Replace the submodule at dotted path `name` with `replacement()`, recursively."""
    head, _, tail = name.partition('.')
    for child_name, child in module.named_children():
        if child_name == name:
            module._modules[child_name] = replacement()
        elif child_name == head:
            module._modules[child_name] = _hier_replacer(child, tail, replacement)
    return module

def _hier_bn_to_identity(module):
    """Recursively substitute BatchNorm1d/2d layers with PACT_Identity."""
    if module.__class__.__name__ in ('BatchNorm2d', 'BatchNorm1d'):
        return PACT_Identity()
    for child_name, child in module.named_children():
        module._modules[child_name] = _hier_bn_to_identity(child)
    return module

def _hier_dropout_to_identity(module):
    """Recursively substitute Dropout layers with PACT_Identity."""
    if module.__class__.__name__ == 'Dropout':
        return PACT_Identity()
    for child_name, child in module.named_children():
        module._modules[child_name] = _hier_dropout_to_identity(child)
    return module

def _hier_bn_quantizer(module, **kwargs):
    """Recursively fold BatchNorm1d/2d layers into PACT_QuantizedBatchNormNd.

    The affine+normalization transform is folded into kappa = gamma/sigma and
    lamda = beta - kappa*mu, so the replacement computes y = kappa*x + lamda.
    """
    clsname = module.__class__.__name__
    if clsname in ('BatchNorm2d', 'BatchNorm1d'):
        gamma = module.weight.data[:].clone().detach()
        beta = module.bias.data[:].clone().detach()
        sigma = torch.sqrt(module.running_var.data[:] + module.eps).clone().detach()
        mu = module.running_mean.data[:].clone().detach()
        ndim = 1 if clsname == 'BatchNorm1d' else 2
        return PACT_QuantizedBatchNormNd(kappa=gamma/sigma, lamda=beta - gamma/sigma*mu, dimensions=ndim, **kwargs)
    for child_name, child in module.named_children():
        module._modules[child_name] = _hier_bn_quantizer(child, **kwargs)
    return module
def _hier_bn_dequantizer(module):
    """Recursively replace PACT_QuantizedBatchNormNd layers with float BatchNorm2d.

    Bugfix: :py:class:`torch.nn.BatchNorm2d` accepts no `weight=`/`bias=`
    constructor keyword arguments, so the original call raised `TypeError`.
    The layer is now constructed from the channel count and the folded affine
    parameters are copied in, with running statistics neutralized so that
    eval-mode output matches y = kappa*x + lamda.
    """
    if module.__class__.__name__ == 'PACT_QuantizedBatchNormNd':
        gamma = module.kappa.data[:].clone().detach().flatten()
        beta = module.lamda.data[:].clone().detach().flatten()
        bn = torch.nn.BatchNorm2d(gamma.numel())
        with torch.no_grad():
            bn.weight.copy_(gamma)
            bn.bias.copy_(beta)
            # neutralize normalization: sqrt(running_var + eps) == 1
            bn.running_mean.zero_()
            bn.running_var.fill_(1.0 - bn.eps)
        return bn
    else:
        for n, m in module.named_children():
            module._modules[n] = _hier_bn_dequantizer(m)
        return module

def _hier_integerizer(module, **kwargs):
    """Recursively turn q-deploy PACT layers into integer-arithmetic versions.

    Linear PACT layers are integerized in place; quantized batch-norm,
    activation and average-pool layers are replaced by their Integer
    counterparts. Non-matching containers are recursed into; all branches
    fall through to the single trailing `return module`.
    """
    if (module.__class__.__name__ == "PACT_Conv2d" or \
        module.__class__.__name__ == "PACT_Conv1d" or \
        module.__class__.__name__ == "PACT_Linear"):
        module.integerize_weights(**kwargs)
        return module
    elif (module.__class__.__name__ == "PACT_QuantizedBatchNormNd"):
        module = PACT_IntegerBatchNormNd(kappa=module.kappa, lamda=module.lamda, eps_in=module.eps_in, eps_kappa=module.eps_kappa, eps_lamda=module.eps_lamda)
        module.integerize_weights(**kwargs)
    elif (module.__class__.__name__ == "PACT_Act"):
        module = PACT_IntegerAct(precision=module.precision, eps_in=module.eps_in, eps_out=module.eps_static, alpha=module.alpha_static, **kwargs)
        module.set_output_eps(**kwargs)
    elif (module.__class__.__name__ == "PACT_IntegerAdd"):
        module.integerized = True
    elif (module.__class__.__name__ == "AvgPool2d"):
        module = PACT_IntegerAvgPool2d(module.kernel_size, stride=module.stride, padding=module.padding, ceil_mode=module.ceil_mode,
            count_include_pad=module.count_include_pad)
    elif (module.__class__.__name__ == "AvgPool1d"):
        module = PACT_IntegerAvgPool1d(module.kernel_size, stride=module.stride, padding=module.padding, ceil_mode=module.ceil_mode,
            count_include_pad=module.count_include_pad)
    else:
        for n, m in module.named_children():
            module._modules[n] = _hier_integerizer(m, **kwargs)
    return module

def _hier_thresholdizer_pact(module):
    """Recursively replace PACT_Act layers with threshold-based activations."""
    if module.__class__.__name__ == 'PACT_Act':
        module = PACT_ThresholdAct(precision=module.precision, alpha=module.alpha.data[:])
        return module
    else:
        for n, m in module.named_children():
            module._modules[n] = _hier_thresholdizer_pact(m)
        return module

def _hier_quantizer_pact(module, graph=None, **kwargs):
    """Recursively replace float layers with PACT quantization-aware versions.

    Conv2d/Conv1d/Linear become PACT_Conv2d/PACT_Conv1d/PACT_Linear (weights
    and bias copied over); ReLU/ReLU6/LeakyReLU become PACT_Act. `graph` is
    currently unused and kept for API compatibility.
    """
    if module.__class__.__name__ == 'Conv2d':
        W = module.weight.data
        try:
            b = module.bias.data
        except AttributeError:
            b = None
        # NOTE: kernel_size/stride/... are already normalized tuples on an
        # existing Conv2d, so _single passes them through unchanged
        module = PACT_Conv2d(
            module.in_channels,
            module.out_channels,
            _single(module.kernel_size),
            stride=_single(module.stride),
            padding=_single(module.padding),
            dilation=_single(module.dilation),
            groups=module.groups,
            bias=True if module.bias is not None else False
        )
        module.weight.data = W.clone()
        if b is not None:
            module.bias.data = b.clone()
        return module
    if module.__class__.__name__ == 'Conv1d':
        W = module.weight.data
        try:
            b = module.bias.data
        except AttributeError:
            b = None
        module = PACT_Conv1d(
            module.in_channels,
            module.out_channels,
            _single(module.kernel_size),
            stride=_single(module.stride),
            padding=_single(module.padding),
            dilation=_single(module.dilation),
            groups=module.groups,
            bias=True if module.bias is not None else False
        )
        module.weight.data = W.clone()
        if b is not None:
            module.bias.data = b.clone()
        return module
    if module.__class__.__name__ == 'Linear':
        W = module.weight.data
        try:
            b = module.bias.data
        except AttributeError:
            b = None
        module = PACT_Linear(
            module.in_features,
            module.out_features,
            bias=True if module.bias is not None else False
        )
        module.weight.data = W.clone()
        if b is not None:
            module.bias.data = b.clone()
        return module
    elif module.__class__.__name__ == 'ReLU6':
        module = PACT_Act(alpha=6., **kwargs)
        return module
    elif module.__class__.__name__ == 'ReLU':
        module = PACT_Act(**kwargs)
        return module
    elif module.__class__.__name__ == 'LeakyReLU':
        module = PACT_Act(leaky=module.negative_slope, **kwargs)
        return module
    else:
        for n, m in module.named_children():
            module._modules[n] = _hier_quantizer_pact(m, **kwargs)
        return module
def _hier_dequantizer_pact(module):
    """Recursively replace PACT layers with their float torch.nn equivalents.

    PACT_Conv2d -> Conv2d, PACT_Linear -> Linear (weights/bias copied),
    PACT_Act -> ReLU.
    NOTE(review): PACT_Conv1d is not handled here — confirm whether that is
    intentional.
    """
    if module.__class__.__name__ == 'PACT_Conv2d':
        W = module.weight.data
        try:
            b = module.bias.data
        except AttributeError:
            b = None
        module = torch.nn.Conv2d(
            module.in_channels,
            module.out_channels,
            _single(module.kernel_size),
            stride=_single(module.stride),
            padding=_single(module.padding),
            dilation=_single(module.dilation),
            groups=module.groups,
            bias=True if module.bias is not None else False
        )
        module.weight.data = W.clone()
        if b is not None:
            module.bias.data = b.clone()
        return module
    if module.__class__.__name__ == 'PACT_Linear':
        W = module.weight.data
        try:
            b = module.bias.data
        except AttributeError:
            b = None
        module = torch.nn.Linear(
            module.in_features,
            module.out_features,
            bias=True if module.bias is not None else False
        )
        module.weight.data = W.clone()
        if b is not None:
            module.bias.data = b.clone()
        return module
    elif module.__class__.__name__ == 'PACT_Act':
        module = torch.nn.ReLU()
        return module
    else:
        for n, m in module.named_children():
            module._modules[n] = _hier_dequantizer_pact(m)
        return module

def _hier_flat_dict_build(module, name):
    """Resolve a dotted layer name to the submodule it denotes."""
    for n, m in module.named_children():
        if n == name:
            return m
        elif n == name.split('.')[0]:
            return _hier_flat_dict_build(m, '.'.join(name.split('.')[1:]))
    return module

def _rebuild_graph_dict(module):
    """Refresh the traced graph's module dictionary, if the module carries one.

    Factored out of the public wrappers below, which all repeated the same
    hasattr/None-check boilerplate.
    """
    if hasattr(module, 'graph') and module.graph is not None:
        module.graph.rebuild_module_dict()

def integerize_pact(module, eps_in, **kwargs):
    r"""Takes a PyTorch module in q-deploy stage and makes it integerized, recursively.

    :param eps_in: input quantum :math:`\varepsilon_{in}`.
    :type eps_in: :py:class:`torch.Tensor`
    :return: the (integerized) module.
    """
    try:
        net = module.module
    except AttributeError:
        net = module
    net.set_eps_in(eps_in)
    net = _hier_integerizer(net, **kwargs)
    # a traced graph is required here (built by quantize_pact with dummy_input)
    net.graph.rebuild_module_dict()
    return module

def dequantize_pact(module):
    """Replace PACT layers with float torch.nn equivalents, recursively."""
    module = _hier_dequantizer_pact(module)
    _rebuild_graph_dict(module)
    return module

def thresholdize_pact(module, act_dict):
    """Replace PACT_Act layers with threshold activations and fold thresholds."""
    module = _hier_thresholdizer_pact(module)
    module.fold_thresholds(act_dict)
    _rebuild_graph_dict(module)
    return module

def bn_to_identity(module):
    """Replace BatchNorm1d/2d layers with identities, recursively."""
    module = _hier_bn_to_identity(module)
    _rebuild_graph_dict(module)
    return module

def dropout_to_identity(module):
    """Replace Dropout layers with identities, recursively."""
    module = _hier_dropout_to_identity(module)
    _rebuild_graph_dict(module)
    return module

def bn_quantizer(module, **kwargs):
    """Fold BatchNorm layers into quantized batch-norm layers, recursively."""
    module = _hier_bn_quantizer(module, **kwargs)
    _rebuild_graph_dict(module)
    return module

def bn_dequantizer(module):
    """Replace quantized batch-norm layers with float BatchNorm, recursively."""
    module = _hier_bn_dequantizer(module)
    _rebuild_graph_dict(module)
    return module
import torch
import logging
import os
import json
import re
import nemo
import numpy as np
from collections import OrderedDict

def precision_dict_to_json(d, filename=None):
    """Serialize a precision dictionary to JSON.

    :param d: precision dictionary (as produced by `export_precision`).
    :param filename: if given, write the JSON there and return None;
        otherwise return the JSON string.
    """
    s = json.dumps(d, indent=4)
    if filename is not None:
        with open(filename, 'w') as f:
            f.write(s)
    else:
        return s

def precision_dict_from_json(filename):
    """Load a precision dictionary from a JSON file."""
    with open(filename, "r") as f:
        rr = json.load(f)
    return rr

def process_json(json, args):
    """Load the training regime referenced by `args.regime` into a dict.

    NOTE: the first parameter is expected to BE the `json` module (it shadows
    the module name); kept for backward compatibility.

    Bugfix: `vars(args)` produces a dict, so the original `args.regime`
    attribute access raised AttributeError; the regime dict was also built
    but never returned.

    :param args: an argparse-style namespace with a `regime` attribute.
    :return: regime dictionary with integer keys where possible.
    """
    args = vars(args)
    regime = {}
    if args.get("regime") is not None:
        with open(args["regime"], "r") as f:
            rr = json.load(f)
        for k in rr.keys():
            # epoch numbers become int keys; anything else stays a string key
            try:
                regime[int(k)] = rr[k]
            except ValueError:
                regime[k] = rr[k]
    return regime

def save_checkpoint(net, optimizer, epoch, acc=0.0, checkpoint_name='net_', checkpoint_suffix=''):
    """Save network/optimizer state plus precision info under ./checkpoint/.

    :param epoch: last completed epoch; stored as `epoch + 1` (next epoch to run).
    """
    checkpoint_name = checkpoint_name + checkpoint_suffix
    logging.info('Saving checkpoint %s' % checkpoint_name)
    try:
        optimizer_state = optimizer.state_dict()
    except AttributeError:
        optimizer_state = None
    try:
        precision = net.export_precision()
    except AttributeError:
        precision = None
    state = {
        'epoch': epoch + 1,
        'state_dict': net.state_dict(),
        'precision': precision,
        'acc': acc,
        'optimizer': optimizer_state,
    }
    # race-free replacement for the isdir()+mkdir() pair
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/%s.pth' % (checkpoint_name))

def export_onnx(file_name, net, net_inner, input_shape, round_params=True, perm=None, redefine_names=False, batch_size=1, verbose=False):
    """Export `net_inner` to ONNX, optionally rounding parameters first.

    :param perm: optional callable applied to the dummy input (e.g. a layout permutation).
    :param round_params: round float32 parameters to avoid numerical noise on writeout.
    :param redefine_names: if True, name graph inputs/outputs after the parameters.
    """
    if perm is None:
        perm = lambda x: x
    pattern = re.compile(r'[\W_]+')
    dummy_input = perm(torch.randn(batch_size, *input_shape, device='cuda' if torch.cuda.is_available() else 'cpu'))
    net.eval()
    # rounding of parameters to avoid strange numerical errors on writeout;
    # in-place writes to leaf tensors that require grad raise a RuntimeError
    # unless autograd is disabled, hence the no_grad() guard
    if round_params:
        with torch.no_grad():
            for param in net_inner.parameters(recurse=True):
                if param.dtype is torch.float32:
                    param[:] = torch.round(param)
    if redefine_names:
        input_names = ['input'] + [pattern.sub('_', n) for n, _ in net_inner.named_parameters()]
        output_names = ['output']
        torch.onnx.export(net_inner, dummy_input, file_name, verbose=verbose, do_constant_folding=True, input_names=input_names, output_names=output_names, export_params=True)
    else:
        torch.onnx.export(net_inner, dummy_input, file_name, verbose=verbose, do_constant_folding=True, export_params=True)
# Keys every relaxation/precision rule must define, with their default values.
PRECISION_RULE_KEYS_REQUIRED = {
    "for_epochs": 1,
    "for_epochs_no_abs_bound": 3,
    "delta_loss_less_than": 0.01,
    "running_avg_memory": 5,
    "delta_loss_running_std_stale": 1.5,
    "abs_loss_stale": 1.4,
    "scale_lr": True,
    "lr_scaler": 4.0,
    "divergence_abs_threshold": 1e9
}
# Optional keys; `bit_scaler`/`bit_stop_condition` are expanded into their
# W_/x_ variants by parse_precision_rule.
PRECISION_RULE_KEYS_ALLOWED = [
    "custom_scaler",
    "bit_scaler",
    "bit_stop_condition"
]

def parse_precision_rule(rule):
    r"""Validate and normalize a relaxation/precision rule dictionary.

    Missing required keys are filled with their defaults; `bit_scaler` and
    `bit_stop_condition` are propagated to their `W_`/`x_` variants when those
    are absent. Keys that are neither allowed nor purely numeric (epoch
    indices) abort the program. The input dict is updated in place and
    returned.

    :param rule: rule dictionary as loaded from JSON.
    :return: the normalized rule dictionary.
    """
    required = list(PRECISION_RULE_KEYS_REQUIRED.keys())
    allowed = PRECISION_RULE_KEYS_ALLOWED + required
    for k in required:
        if k not in rule:
            rule[k] = PRECISION_RULE_KEYS_REQUIRED[k]
    flag = False
    for k in rule.keys():
        if k not in allowed and not k.isdigit():
            # bugfix: the original passed `k` as a second positional argument
            # to print(), so the message was never formatted
            print("[ERROR] %s is not a key allowed in the relaxation rule" % k)
            flag = True
    if "bit_scaler" in rule and "W_bit_scaler" not in rule:
        rule["W_bit_scaler"] = rule["bit_scaler"]
    if "bit_scaler" in rule and "x_bit_scaler" not in rule:
        rule["x_bit_scaler"] = rule["bit_scaler"]
    if "bit_stop_condition" in rule and "W_bit_stop_condition" not in rule:
        rule["W_bit_stop_condition"] = rule["bit_stop_condition"]
    if "bit_stop_condition" in rule and "x_bit_stop_condition" not in rule:
        rule["x_bit_stop_condition"] = rule["bit_stop_condition"]
    if flag:
        import sys
        sys.exit(1)
    print(list(rule.keys()))
    return rule
sys.exit(1) 132 | print(list(rule.keys())) 133 | return rule 134 | 135 | # see https://github.com/sksq96/pytorch-summary 136 | def get_summary(net, input_size, batch_size=1, device="cuda", verbose=False): 137 | s = "" 138 | mdict = {} 139 | for n,m in net.named_modules(): 140 | mdict[n] = m 141 | def register_hook(module): 142 | def hook(module, input, output): 143 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 144 | module_idx = len(summary) 145 | m_key = next(n for n,m in mdict.items() if m==module) 146 | summary[m_key] = OrderedDict() 147 | summary[m_key]["input_shape"] = list(input[0].size()) 148 | summary[m_key]["input_shape"][0] = batch_size 149 | if isinstance(output, (list, tuple)): 150 | summary[m_key]["output_shape"] = [ 151 | [-1] + list(o.size())[1:] for o in output 152 | ] 153 | else: 154 | summary[m_key]["output_shape"] = list(output.size()) 155 | summary[m_key]["output_shape"][0] = batch_size 156 | 157 | params = 0 158 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 159 | try: 160 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) / module.group 161 | except AttributeError: 162 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 163 | summary[m_key]["trainable"] = module.weight.requires_grad 164 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 165 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 166 | summary[m_key]["nb_params"] = params 167 | 168 | if hasattr(module, "W_precision"): 169 | summary[m_key]['W_bits'] = module.W_precision.get_bits() 170 | 171 | if hasattr(module, "precision"): 172 | summary[m_key]['bits'] = module.precision.get_bits() 173 | 174 | if ( 175 | not isinstance(module, torch.nn.Sequential) 176 | and not isinstance(module, torch.nn.ModuleList) 177 | and not (module == net) 178 | ): 179 | hooks.append(module.register_forward_hook(hook)) 180 | 181 | device = device.lower() 182 | assert device in [ 183 | "cuda", 184 | "cpu", 
185 | ], "Input device is not valid, please specify 'cuda' or 'cpu'" 186 | 187 | if device == "cuda" and torch.cuda.is_available(): 188 | dtype = torch.cuda.FloatTensor 189 | else: 190 | dtype = torch.FloatTensor 191 | 192 | # multiple inputs to the network 193 | if isinstance(input_size, tuple): 194 | input_size = [input_size] 195 | 196 | # batch_size of 2 for batchnorm 197 | x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size] 198 | 199 | # create properties 200 | summary = OrderedDict() 201 | hooks = [] 202 | 203 | # register hook 204 | net.apply(register_hook) 205 | 206 | # make a forward pass 207 | net(*x) 208 | 209 | # remove these hooks 210 | for h in hooks: 211 | h.remove() 212 | 213 | s += "----------------------------------------------------------------" + "\n" 214 | line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #") 215 | s += line_new + "\n" 216 | s += "================================================================" + "\n" 217 | total_params = 0 218 | total_output = 0 219 | trainable_params = 0 220 | params_size = 0 221 | output_size = 0 222 | input_size = 0 223 | for layer in summary: 224 | # input_shape, output_shape, trainable, nb_params 225 | line_new = "{:>20} {:>25} {:>15}".format( 226 | layer, 227 | str(summary[layer]["output_shape"]), 228 | "{0:,}".format(summary[layer]["nb_params"]), 229 | ) 230 | total_params += summary[layer]["nb_params"] 231 | try: 232 | params_size += abs(summary[layer]["nb_params"] * summary[layer]["W_bits"] / 8. / (1024.)) 233 | except KeyError: 234 | params_size += abs(summary[layer]["nb_params"] * 32. / 8. 
/ (1024.)) 235 | total_output += np.prod(summary[layer]["output_shape"]) 236 | try: 237 | output_size = max(output_size, np.prod(summary[layer]["output_shape"]) * summary[layer]["bits"] / 8 / (1024.)) 238 | except KeyError: 239 | output_size = max(output_size, np.prod(summary[layer]["output_shape"]) * 32 / 8 / (1024.)) 240 | if "trainable" in summary[layer]: 241 | if summary[layer]["trainable"] == True: 242 | trainable_params += summary[layer]["nb_params"] 243 | s += line_new + "\n" 244 | 245 | s += "================================================================" + "\n" 246 | s += "Total params: {0:,}".format(total_params) + "\n" 247 | s += "Trainable params: {0:,}".format(trainable_params) + "\n" 248 | s += "Non-trainable params: {0:,}".format(total_params - trainable_params) + "\n" 249 | s += "----------------------------------------------------------------" + "\n" 250 | s += "Biggest activation tensor size (kB): {0:,.2f}".format(output_size) + "\n" 251 | s += "Params size (kB): {0:,.1f}".format(params_size) + "\n" 252 | s += "----------------------------------------------------------------" + "\n" 253 | if verbose: 254 | logging.info(s) 255 | return { 'dict': summary, 'prettyprint': s, 'biggest_activation': output_size, 'params_size': params_size } 256 | 257 | def get_intermediate_activations(net, test_fn, *test_args, **test_kwargs): 258 | l = len(list(net.named_modules())) 259 | buffer_in = OrderedDict([]) 260 | buffer_out = OrderedDict([]) 261 | hooks = OrderedDict([]) 262 | def get_hk(n): 263 | def hk(module, input, output): 264 | buffer_in [n] = input 265 | buffer_out [n] = output 266 | return hk 267 | for i,(n,l) in enumerate(net.named_modules()): 268 | hk = get_hk(n) 269 | hooks[n] = l.register_forward_hook(hk) 270 | ret = test_fn(*test_args, **test_kwargs) 271 | for n,l in net.named_modules(): 272 | hooks[n].remove() 273 | return buffer_in, buffer_out, ret 274 | 275 | def get_intermediate_eps(net, eps_in): 276 | l = len(list(net.named_modules())) 277 | 
eps = OrderedDict([]) 278 | for i,(n,l) in enumerate(net.named_modules()): 279 | eps[n] = net.get_eps_at(n, eps_in) 280 | return eps 281 | 282 | def get_integer_activations(buf, eps, net=None): 283 | if type(eps) is float and net is None: 284 | return buf 285 | elif type(eps) is float: 286 | eps_in = eps 287 | eps = OrderedDict([]) 288 | for n,m in net.named_modules(): 289 | try: 290 | eps[n] = m.get_output_eps(eps_in) 291 | except AttributeError: 292 | pass 293 | buf_ = OrderedDict([]) 294 | for n in buf.keys(): 295 | b = buf.get(n, None) 296 | e = eps.get(n, None) 297 | if b is None or e is None: 298 | continue 299 | if type(buf[n]) is tuple or type(buf[n]) is list: 300 | buf_[n] = [] 301 | for b in buf[n]: 302 | buf_[n].append((b / eps[n]).floor()) # FIXME 303 | else: 304 | buf_[n] = (buf[n] / eps[n]).floor() 305 | return buf_ 306 | 307 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.3.0 2 | torchvision>=0.4.1 3 | numpy 4 | tqdm 5 | packaging 6 | scikit-learn 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="pytorch-nemo", 8 | version="0.0.8", 9 | author="Francesco Conti", 10 | author_email="f.conti@unibo.it", 11 | description="NEural Minimizer for pytOrch", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/pulp-platform/nemo", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.5', 22 | install_requires=[ 
23 | "torch>=1.3.0", 24 | "torchvision>=0.4.1", 25 | "numpy", 26 | "tqdm", 27 | "packaging", 28 | "scikit-learn" 29 | ] 30 | ) 31 | -------------------------------------------------------------------------------- /tests/mnist_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """NEMO: post-training quantization MNIST example 4 | 5 | Automatically generated by Colaboratory. 6 | 7 | Original file is located at 8 | https://colab.research.google.com/drive/1AmcITfN2ELQe07WKQ9szaxq-WSu4hdQb 9 | 10 | This example guides a first-time user of NEMO into the NEMO quantization process, using a small pretrained network and going through post-training per-layer quantization (i.e., representing weight and activation tensors as integers) and deployment (i.e., organizing operations so that they are an accurate representation of behavior in integer-based hardware). We will see how this operates through four stages: *FullPrecision*, *FakeQuantized*,*QuantizedDeployable*, *IntegerDeployable*. 11 | 12 | NEMO uses `float32` tensors to represent data at all four stages - including *IntegerDeployable*. This means that NEMO code does not need special hardware support for integers to run on GPUs. It also means that NEMO is not (and does not want to be) a runtime for quantized neural networks on embedded systems! 13 | 14 | Let us start by installing dependencies... 15 | """ 16 | 17 | """... and import all packages, including NEMO itself:""" 18 | 19 | import argparse 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import torch.optim as optim 24 | from torchvision import datasets, transforms 25 | import nemo 26 | from tqdm import tqdm 27 | 28 | """The first real step is to define the network topology. This works exactly like in a "standard" PyTorch script, using regular `torch.nn.Module` instances. NEMO can transform many layers defined in `torch.nn` into its own representations. 
There are a few constraints, however, to the network construction: 29 | * Use `torch.nn.Module`, not `torch.autograd.Function`: NEMO works by listing modules and ignores functions by construction. Everything coming from the `torch.nn.functional` library is ignored by NEMO: for example, you have to use `torch.nn.ReLU` module instead of the equivalent `torch.nn.functional.relu` function, which is often found in examples online. 30 | * Instantiate a separate `torch.nn.Module` for each node in your topology; you already do this for parametric modules (e.g., `torch.nn.Conv2d`), but you have to do the same also for `torch.nn.ReLU` and other parameterless modules. NEMO will introduce quantization parameters that will change along the network. 31 | * To converge two network branches (e.g., a main and a residual branch), a normal PyTorch network will usually add the values of their output tensors. This will keep working for a network at the *FakeQuantized* stage, i.e., one that can be fine-tuned keeping into account quantization - but it will break in later stages. In the *QuantizedDeployable* and *IntegerDeployable* stages, the branch reconvergence has to take into account the possibly different precision of the two branches, therefore NEMO has to know that there is an "Add" node at that point of the network. This can be realized using the `nemo.quant.pact.PACT_IntegerAdd` module, which is entirely equivalent to a normal addition in *FullPrecision* and *FakeQuantized* stages. 32 | """ 33 | 34 | class ExampleNet(nn.Module): 35 | def __init__(self): 36 | super(ExampleNet, self).__init__() 37 | self.conv1 = nn.Conv2d(1, 32, 3, 1) 38 | self.bn1 = nn.BatchNorm2d(32) 39 | self.relu1 = nn.ReLU() # <== Module, not Function! 40 | self.conv2 = nn.Conv2d(32, 64, 3, 1) 41 | self.bn2 = nn.BatchNorm2d(64) 42 | self.relu2 = nn.ReLU() # <== Module, not Function! 43 | self.pool2 = nn.MaxPool2d(2) 44 | self.fc1 = nn.Linear(9216, 256) 45 | self.fcrelu1 = nn.ReLU() # <== Module, not Function! 
46 | self.fc2 = nn.Linear(256, 10) 47 | 48 | def forward(self, x): 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | x = self.relu1(x) # <== Module, not Function! 52 | x = self.conv2(x) 53 | x = self.bn2(x) 54 | x = self.relu2(x) # <== Module, not Function! 55 | x = self.pool2(x) 56 | x = torch.flatten(x, 1) 57 | x = self.fc1(x) 58 | x = self.fcrelu1(x) # <== Module, not Function! 59 | x = self.fc2(x) 60 | output = F.log_softmax(x, dim=1) # <== the softmax operation does not need to be quantized, we can keep it as it is 61 | return output 62 | 63 | """Then we define the training and testing functions (MNIST has no validation set). These are essentially identical to regular PyTorch code, with only one difference: testing (and validation) functions have a switch to support the production of non-negative integer data. This is important to test the last stage of quantization, i.e., *IntegerDeployable*. Of course, this change might also be effectively performed inside the data loaders; in this example, we use standard `torchvision` data loaders for MNIST.""" 64 | 65 | # convenience class to keep track of averages 66 | class Metric(object): 67 | def __init__(self, name): 68 | self.name = name 69 | self.sum = torch.tensor(0.) 70 | self.n = torch.tensor(0.) 
71 | def update(self, val): 72 | self.sum += val.cpu() 73 | self.n += 1 74 | @property 75 | def avg(self): 76 | return self.sum / self.n 77 | 78 | def train(model, device, train_loader, optimizer, epoch, verbose=False): 79 | model.train() 80 | train_loss = Metric('train_loss') 81 | with tqdm(total=len(train_loader), 82 | desc='Train Epoch #{}'.format(epoch + 1), 83 | disable=not verbose) as t: 84 | for batch_idx, (data, target) in enumerate(train_loader): 85 | data, target = data.to(device), target.to(device) 86 | optimizer.zero_grad() 87 | output = model(data) 88 | loss = F.nll_loss(output, target) 89 | loss.backward() 90 | optimizer.step() 91 | train_loss.update(loss) 92 | t.set_postfix({'loss': train_loss.avg.item()}) 93 | t.update(1) 94 | return train_loss.avg.item() 95 | 96 | def test(model, device, test_loader, integer=False, verbose=False): 97 | model.eval() 98 | test_loss = 0 99 | correct = 0 100 | test_acc = Metric('test_acc') 101 | with tqdm(total=len(test_loader), 102 | desc='Test', 103 | disable=not verbose) as t: 104 | with torch.no_grad(): 105 | for data, target in test_loader: 106 | if integer: # <== this will be useful when we get to the 107 | data *= 255 # IntegerDeployable stage 108 | data, target = data.to(device), target.to(device) 109 | output = model(data) 110 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 111 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 112 | correct += pred.eq(target.view_as(pred)).sum().item() 113 | test_acc.update((pred == target.view_as(pred)).float().mean()) 114 | t.set_postfix({'acc' : test_acc.avg.item() * 100. }) 115 | t.update(1) 116 | test_loss /= len(test_loader.dataset) 117 | return test_acc.avg.item() * 100. 
118 | 119 | """Set up the dataset loaders.""" 120 | 121 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 122 | kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {} 123 | train_loader = torch.utils.data.DataLoader( 124 | datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([ 125 | transforms.ToTensor() 126 | ])), 127 | batch_size=128, shuffle=True, **kwargs 128 | ) 129 | test_loader = torch.utils.data.DataLoader( 130 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 131 | transforms.ToTensor() 132 | ])), 133 | batch_size=128, shuffle=False, **kwargs 134 | ) 135 | 136 | """Download a pretrained model, and test it! Here we operate at what we call the ***FullPrecision*** stage: the regular PyTorch representation, which relies on real-valued tensors represented by `float32` in your CPU/GPU.""" 137 | 138 | # !wget https://raw.githubusercontent.com/FrancescoConti/nemo_examples_helper/master/mnist_cnn_fp.pt 139 | model = ExampleNet().to(device) 140 | state_dict = torch.load("mnist_cnn_fp.pt", map_location='cpu') 141 | model.load_state_dict(state_dict, strict=True) 142 | acc = test(model, device, test_loader) 143 | print("\nFullPrecision accuracy: %.02f%%" % acc) 144 | assert acc >= 99.0 145 | 146 | """Now, it's time to try out some post-training quantization. We will do so by switching to the ***FakeQuantized*** stage. This representation is very similar to *FullPrecision*, as it still uses real-valued tensors for weights and activations. However, activation functions such as ReLU become quantization functions, imposing that the output is representable in a certain number of steps. 
Mathematically, 147 | $$ 148 | \mathbf{y} = \mathrm{ReLU}(\mathbf{x}) = \mathrm{clip}_{ [0,\infty) }(\mathbf{x}) \quad\longrightarrow\quad \mathbf{y} = \left\lfloor 1/\varepsilon \cdot \mathrm{clip}_{ [0,\alpha) } (\mathbf{x}) \right\rfloor \cdot \varepsilon 149 | $$ 150 | 151 | Here, we introduce two changes to the ReLU. First, the clipping function is not only clipping at 0, but also at a maximum value $\alpha$, which can be set to the maximum value of $\mathbf{y}$ in the *FullPrecision* stage (see later). Second, we introduce a *quantum* $\varepsilon$, which is the smallest real-valued amount that is representable in $\mathbf{y}$. With $Q$ bits, $\varepsilon = \alpha / (2^Q - 1)$. 152 | Network weights are passed through a similar function before being used, with the main difference that they are not strictly non-negative: they are clipped between two values $\alpha$ and $\beta$. 153 | 154 | At this stage, the network can be trained / fine-tuned (see details in the fine-tuning example). Also, the numerical values of activations can differ a bit from a real hardware implementation. This is because quantization of tensors is only actively imposed in the activation functions, but not in the other operations (i.e., in ReLU, not in BatchNorm2d and Conv2d). 155 | 156 | Here, we first transform the model in a *FakeQuantized* version targeting a very relaxed 16-bit quantization for weights and activations. The quantization for each layer can be tweaked by means of a precision dictionary; notice that for weights we actually impose 15 bits instead of 16: this is to take into account the asymmetry of $\alpha$ and $\beta$.
157 | """ 158 | 159 | model = nemo.transform.quantize_pact(model, dummy_input=torch.randn((1,1,28,28)).to(device)) 160 | precision = { 161 | 'conv1': { 162 | 'W_bits' : 15 163 | }, 164 | 'conv2': { 165 | 'W_bits' : 15 166 | }, 167 | 'fc1': { 168 | 'W_bits' : 15 169 | }, 170 | 'fc2': { 171 | 'W_bits' : 15 172 | }, 173 | 'relu1': { 174 | 'x_bits' : 16 175 | }, 176 | 'relu2': { 177 | 'x_bits' : 16 178 | }, 179 | 'fcrelu1': { 180 | 'x_bits' : 16 181 | }, 182 | } 183 | model.change_precision(bits=1, min_prec_dict=precision) 184 | acc = test(model, device, test_loader) 185 | print("\nFakeQuantized @ 16b accuracy (first try): %.02f%%" % acc) 186 | assert acc >= 80.0 187 | 188 | """The first try looks... not so good. 82% is actually pretty bad for MNIST! What happened? Remember that while clipping parameters for weights can be set statically, this is not true for activations: so the missing piece is the characterization of activation clipping ($\alpha$ parameter), which is currently set to a default value. 189 | 190 | In NEMO, this initial calibration can be performed by setting a special *statistics collection* mode for activations, which is used to explicitly reset the $\alpha$ params. The calibration is performed directly by running inference over a dataset; in this case, we cheat a bit and do it on the test set. 191 | """ 192 | 193 | model.set_statistics_act() 194 | _ = test(model, device, test_loader) 195 | model.unset_statistics_act() 196 | model.reset_alpha_act() 197 | acc = test(model, device, test_loader) 198 | print("\nFakeQuantized @ 16b accuracy (calibrated): %.02f%%" % acc) 199 | assert acc >= 99.0 200 | 201 | """Now the accuracy is substantially the same as the initial one! This is what we expect to see using a very conservative quantization scheme with 16 bits. Due to the way that NEMO implements the *FakeQuantized* stage, it is very easy to explore what happens by imposing a stricter or mixed precision quantization scheme. 
The number of bits we can use is very free: we can even set it to "fractionary" values if we want, which corresponds to intermediate $\varepsilon$ *quantum* sizes with respect to the nearest integers. For example, let's force `conv1`, `conv2` and `fc1` to be 7 bits, `fc2` to use only 3 bits for its parameters, and all activations to be 8-bit. 202 | 203 | After the experiment, we also save the network in a PyTorch checkpoint file, because afterwards we'll start doing some destructive transformations... 204 | """ 205 | 206 | precision = { 207 | 'conv1': { 208 | 'W_bits' : 7 209 | }, 210 | 'conv2': { 211 | 'W_bits' : 7 212 | }, 213 | 'fc1': { 214 | 'W_bits' : 7 215 | }, 216 | 'fc2': { 217 | 'W_bits' : 3 218 | }, 219 | 'relu1': { 220 | 'x_bits' : 8 221 | }, 222 | 'relu2': { 223 | 'x_bits' : 8 224 | }, 225 | 'fcrelu1': { 226 | 'x_bits' : 8 227 | }, 228 | } 229 | model.change_precision(bits=1, min_prec_dict=precision) 230 | acc = test(model, device, test_loader) 231 | print("\nFakeQuantized @ mixed-precision accuracy: %.02f%%" % acc) 232 | assert acc >= 99.0 233 | nemo.utils.save_checkpoint(model, None, 0, checkpoint_name='mnist_fq_mixed') 234 | 235 | """Since MNIST is very easy, there is only a very small accuracy reduction despite the aggressive reduction at the end of the network. Now, let us progress towards a possible deployment. One of the possible steps is the so-called *folding* of batch-normalization layers. With this operation, the normalization performed by batch-norm layers is absorbed within the parameters of the convolutional layers. 
To perform it, we first do the folding itself, then we reset the clipping parameters of weights (because the weights change their value!).""" 236 | 237 | model.fold_bn() 238 | model.reset_alpha_weights() 239 | acc = test(model, device, test_loader) 240 | print("\nFakeQuantized @ mixed-precision (folded) accuracy: %.02f%%" % acc) 241 | assert acc >= 98.8 242 | 243 | """Notice a small reduction in accuracy: as you might remember, the batch-norm layers were not quantized before folding; folding absorbs them inside the quantized parameters. Therefore a small reduction is expected! There are also a few techniques that can be used to recover accuracy, such as the weight equaliztion for Data Free Quantization proposed by Nagel et al. (https://arxiv.org/abs/1906.04721) . Here we try it on our network, which requires also a new calibration pass.""" 244 | 245 | model.equalize_weights_dfq({'conv1':'conv2'}, reset_alpha=False) 246 | model.set_statistics_act() 247 | _ = test(model, device, test_loader) 248 | model.unset_statistics_act() 249 | model.reset_alpha_act() 250 | acc = test(model, device, test_loader) 251 | print("\nFakeQuantized @ mixed-precision (folded+equalized) accuracy: %.02f%%" % acc) 252 | assert acc >= 98.8 253 | 254 | """Now we go back one step, reloading the state from the saved checkpoint (before folding) to show the "standard" deployment strategy that we use, based on Integer Batch-Norm (Rusci et al., https://arxiv.org/abs/1905.13082). 255 | This is organized in two steps: first, we replace all `BatchNorm2d` in the network into a special quantized form, which is equivalent to freezing their parameters and transforms them, essentially, in channel-wise affine transforms. Then, we harden weights in their current quantum representation. Finally, we use the `set_deployment` method to bring the network to the ***QuantizedDeployable*** stage. 
256 | 257 | In the *QuantizedDeployable* stage, the network is still using real-valued weights and activations; but all operations consume and produce quantized tensors that can always be decomposed in the product of a real-valued quantum with an integer tensor, which we call the *integer image*. The network cannot be trained any more, but it can be exported in an ONNX graph that faithfully represents quantization. 258 | 259 | Bringing the network to this stage requires setting an input quantum $\varepsilon_{in}$, which corresponds to the value of 1 bit in the integer representation of our input. For example, for an 8-bit image represented by a tensor in the range $[0,1]$, this will be $1/255$. 260 | Internally, NEMO uses a graph representation of the network to propagate the $\varepsilon$ to all layers that are not explicitly quantized (in practice, all layers that are not activations). 261 | 262 | In the next cell, we bring our MNIST network to this representation, then we test it to verify that it still works. 263 | """ 264 | 265 | state_dict = torch.load('checkpoint/mnist_fq_mixed.pth')['state_dict'] 266 | model.load_state_dict(state_dict, strict=True) 267 | model.qd_stage(eps_in=1./255) 268 | print(model) 269 | acc = test(model, device, test_loader) 270 | print("\nQuantizedDeployable @ mixed-precision accuracy: %.02f%%" % acc) 271 | assert acc >= 99.0 272 | 273 | """The *QuantizedDeployable* network is accurate only in the sense that the operations keep all quantization assumptions. It is not, however, bit-accurate with respect to deployment on an integer-only hardware platform. To get that level of accuracy, we have to transform the network to the last stage: ***IntegerDeployable***. 274 | 275 | At this stage, the network can essentially "forget" about the quantum and only work on integer images in all nodes. This means that all weights and activations are replaced by integers! 
The next cells shows the transformation and how the final integer network still achieves the full accuracy... if we remember that now also test data has to be represented in an integer format. 276 | """ 277 | 278 | model = nemo.transform.integerize_pact(model, eps_in=1.0/255) 279 | print(model) 280 | acc = test(model, device, test_loader, integer=True) 281 | print("\nIntegerDeployable @ mixed-precision accuracy: %.02f%%" % acc) 282 | assert acc >= 99.0 283 | -------------------------------------------------------------------------------- /tests/mobi_fq_qd_id/mobi_fq_qd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import PIL 3 | import os 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | from mobilenet import mobilenet 12 | from torch.autograd import Variable 13 | from datetime import datetime 14 | from ast import literal_eval 15 | import json 16 | from torchvision.utils import save_image 17 | from tqdm import tqdm 18 | import nemo 19 | import warnings 20 | import math 21 | import copy 22 | import collections 23 | 24 | # -a mobilenet \ 25 | # --mobilenet_width 1.0 \ 26 | # --mobilenet_input 128 \ 27 | # --dataset imagenet \ 28 | # --weight_bits 8 \ 29 | # --activ_bits 8 \ 30 | # --gpus 0 \ 31 | # -j 8 \ 32 | # --epochs 12 \ 33 | # -b 128 \ 34 | # --quantize \ 35 | # --terminal \ 36 | # --resume checkpoint/mobilenet_1.0_128_best.pth 37 | 38 | # filter out ImageNet EXIF warnings 39 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 40 | warnings.filterwarnings("ignore", "Metadata Warning", UserWarning) 41 | 42 | model_names = ['mobilenet',] 43 | 44 | parser = argparse.ArgumentParser(description='PyTorch ConvNet Training') 45 | 46 | parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results', 47 | help='results dir') 48 | 
parser.add_argument('--save', metavar='SAVE', default='', 49 | help='saved folder') 50 | parser.add_argument('--dataset', metavar='DATASET', default='imagenet', 51 | help='dataset name or folder') 52 | parser.add_argument('--model', '-a', metavar='MODEL', default='mobilenet', 53 | choices=model_names, 54 | help='model architecture: ' + 55 | ' | '.join(model_names) + 56 | ' (default: alexnet)') 57 | parser.add_argument('--input_size', type=int, default=None, 58 | help='image input size') 59 | parser.add_argument('--model_config', default='', 60 | help='additional architecture configuration') 61 | parser.add_argument('--type', default='torch.FloatTensor', 62 | help='type of tensor - e.g torch.cuda.HalfTensor') 63 | parser.add_argument('--gpus', default='', 64 | help='gpus used for training - e.g 0,1,3') 65 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 66 | help='number of data loading workers (default: 8)') 67 | parser.add_argument('--epochs', default=150, type=int, metavar='N', 68 | help='number of total epochs to run') 69 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 70 | help='manual epoch number (useful on restarts)') 71 | parser.add_argument('-b', '--batch-size', default=256, type=int, 72 | metavar='N', help='mini-batch size (default: 256)') 73 | parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT', 74 | help='optimizer function used') 75 | parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, 76 | metavar='LR', help='initial learning rate') 77 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 78 | help='momentum') 79 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 80 | metavar='W', help='weight decay (default: 1e-4)') 81 | parser.add_argument('--print-freq', '-p', default=100, type=int, 82 | metavar='N', help='print frequency (default: 10)') 83 | parser.add_argument('--resume', default="mobilenet_1.0_128_best.pth", 
type=str, metavar='PATH', 84 | help='path to latest checkpoint (default: none)') 85 | parser.add_argument('-e', '--evaluate', action='store_true', 86 | help='run model on validation set') 87 | parser.add_argument('--save_check', action='store_true', 88 | help='saving the checkpoint') 89 | parser.add_argument('--terminal', action='store_true') 90 | # quantization parameters 91 | parser.add_argument('--quantize', default=True, action='store_true', 92 | help='quantize the network') 93 | parser.add_argument('--type_quant', default=None, 94 | help='Type of binarization process') 95 | parser.add_argument('--weight_bits', default=8, 96 | help='Number of bits for the weights') 97 | parser.add_argument('--activ_bits', default=8, 98 | help='Number of bits for the activations') 99 | 100 | parser.add_argument('--initial_folding', default=False, action='store_true', 101 | help='Fold BNs into Linear layers before training') 102 | parser.add_argument('--initial_equalization', default=False, action='store_true', 103 | help='Perform Linear layer weight equalization before training') 104 | parser.add_argument('--quant_add_config', default='', type=str, 105 | help='Additional config of per-layer quantization') 106 | 107 | # mobilenet params 108 | parser.add_argument('--mobilenet_width', default=1.0, type=float, 109 | help='Mobilenet Width Muliplier') 110 | parser.add_argument('--mobilenet_input', default=128, type=int, 111 | help='Mobilenet input resolution ') 112 | 113 | # mixed-precision params 114 | parser.add_argument('--mem_constraint', default='', type=str, 115 | help='Memory constraints for automatic bitwidth quantization') 116 | parser.add_argument('--mixed_prec_quant', default='MixPL', type=str, 117 | help='Type of quantization for mixed-precision low bitwidth: MixPL | MixPC') 118 | 119 | 120 | def main(): 121 | global args, best_prec1 122 | best_prec1 = 0 123 | args = parser.parse_args() 124 | 125 | weight_bits = int(args.weight_bits) 126 | activ_bits = int(args.activ_bits) 
127 | 128 | print("run arguments: %s" % args) 129 | 130 | args.gpus = None 131 | 132 | # create model 133 | print("creating model %s" % args.model) 134 | model_config = {'input_size': args.input_size, 'dataset': args.dataset, 'num_classes': 1000, \ 135 | 'width_mult': float(args.mobilenet_width), 'input_dim': float(args.mobilenet_input) } 136 | 137 | model = mobilenet(**model_config).to('cpu') 138 | print("created model with configuration: %s" % model_config) 139 | print(model) 140 | 141 | 142 | mobilenet_width = float(args.mobilenet_width) 143 | mobilenet_input = int(args.mobilenet_input) 144 | 145 | # transform the model in a NEMO FakeQuantized representation 146 | model = nemo.transform.quantize_pact(model, dummy_input=torch.randn((1,3,mobilenet_input,mobilenet_input)).to('cpu')) 147 | 148 | checkpoint_file = args.resume 149 | if os.path.isfile(checkpoint_file): 150 | print("loading checkpoint '%s'" % args.resume) 151 | checkpoint_loaded = torch.load(checkpoint_file, map_location=torch.device('cpu')) 152 | checkpoint = checkpoint_loaded['state_dict'] 153 | model.load_state_dict(checkpoint, strict=True) 154 | prec_dict = checkpoint_loaded.get('precision') 155 | else: 156 | print("no checkpoint found at '%s'" % args.resume) 157 | import sys; sys.exit(1) 158 | 159 | print("[NEMO] Not calibrating model, as it is pretrained") 160 | model.change_precision(bits=1, min_prec_dict=prec_dict) 161 | 162 | inputs = torch.load("input_fq.pth", map_location=torch.device('cpu'))['in'] 163 | inputs = inputs[:8] # reduce input size for GitHub CI regression test 164 | bin_fq, bout_fq, _ = nemo.utils.get_intermediate_activations(model, forward, model, inputs) 165 | 166 | input_bias_dict = {}# {'model.0.0' : +1.0, 'model.0.1' : +1.0} 167 | remove_bias_dict = {}#{'model.0.1' : 'model.0.2'} 168 | input_bias = 0 #math.floor(1.0 / (2./255)) * (2./255) 169 | 170 | model.qd_stage(eps_in=2./255, int_accurate=False) 171 | # fix ConstantPad2d 172 | # model.model[0][0].value = input_bias 173 | 
def forward(model, inputs, input_bias=0.0, eps_in=None, integer=False):
    """Run one inference pass of `model` on `inputs` for activation capture.

    This is the callback handed to `nemo.utils.get_intermediate_activations`,
    which records per-layer activations through hooks while this runs.

    :param model: network to evaluate (switched to eval mode here).
    :param inputs: input batch tensor.
    :param input_bias: constant added to the inputs before scaling.
    :param eps_in: input quantum; when given, inputs are multiplied by
        1/eps_in to move them onto the integer grid.
    :param integer: kept for interface compatibility with the
        integer-deployable variant of this test; the input scaling applied
        here is identical either way (the original also computed an unused
        `div_factor` that differed — it was dead code and has been removed).
    :returns: the model output tensor (BUGFIX: the original computed the
        output but never returned it; returning it is backward-compatible
        since existing callers rely on hooks and ignore the return value).
    """
    model.eval()

    # input quantization: scale real-valued inputs by 1/eps_in when requested
    if eps_in is None:
        scale_factor = 1.
    else:
        scale_factor = 1. / eps_in

    with torch.no_grad():
        if eps_in is None:
            input_var = (inputs + input_bias)
        else:
            input_var = (inputs + input_bias) * scale_factor

        # compute output
        output = model(input_var)

    return output

if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------
# tests/mobi_fq_qd_id/mobi_qd_id.py
# --------------------------------------------------------------------------

import argparse
import PIL
import os
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from mobilenet import mobilenet
from torch.autograd import Variable
from datetime import datetime
from ast import literal_eval
import json
from torchvision.utils import save_image
from tqdm import tqdm
import nemo
import warnings
import math
import copy
import collections

# Reference command line this test mirrors:
# -a mobilenet \
# --mobilenet_width 1.0 \
# --mobilenet_input 128 \
# --dataset imagenet \
# --weight_bits 8 \
# --activ_bits 8 \
# --gpus 0 \
# -j 8 \
# --epochs 12 \
# -b 128 \
# --quantize \
# --terminal \
# --resume checkpoint/mobilenet_1.0_128_best.pth

# Set SAVE_RESULTS=True to regenerate the golden-results file instead of
# asserting against it; TOL_* are the tolerances applied to the golden values.
SAVE_RESULTS = False
TOL_RESULTS = 1.01
TOL_PERCENT = 1.1

# filter out ImageNet EXIF warnings
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
warnings.filterwarnings("ignore", "Metadata Warning", UserWarning)

model_names = ['mobilenet',]

parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')

parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results',
                    help='results dir')
parser.add_argument('--save', metavar='SAVE', default='',
                    help='saved folder')
parser.add_argument('--dataset', metavar='DATASET', default='imagenet',
                    help='dataset name or folder')
parser.add_argument('--model', '-a', metavar='MODEL', default='mobilenet',
                    choices=model_names,
                    help='model architecture: ' +
                    ' | '.join(model_names) +
                    ' (default: alexnet)')
parser.add_argument('--input_size', type=int, default=None,
                    help='image input size')
parser.add_argument('--model_config', default='',
                    help='additional architecture configuration')
parser.add_argument('--type', default='torch.FloatTensor',
                    help='type of tensor - e.g torch.cuda.HalfTensor')
parser.add_argument('--gpus', default='',
                    help='gpus used for training - e.g 0,1,3')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')
parser.add_argument('--epochs', default=150, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--optimizer', default='SGD', type=str, metavar='OPT',
                    help='optimizer function used')
parser.add_argument('--lr', '--learning_rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=100, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default="mobilenet_1.0_128_best.pth", type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', action='store_true',
                    help='run model on validation set')
parser.add_argument('--save_check', action='store_true',
                    help='saving the checkpoint')
parser.add_argument('--terminal', action='store_true')
# quantization parameters
parser.add_argument('--quantize', default=True, action='store_true',
                    help='quantize the network')
parser.add_argument('--type_quant', default=None,
                    help='Type of binarization process')
parser.add_argument('--weight_bits', default=8,
                    help='Number of bits for the weights')
parser.add_argument('--activ_bits', default=8,
                    help='Number of bits for the activations')

parser.add_argument('--initial_folding', default=False, action='store_true',
                    help='Fold BNs into Linear layers before training')
parser.add_argument('--initial_equalization', default=False, action='store_true',
                    help='Perform Linear layer weight equalization before training')
parser.add_argument('--quant_add_config', default='', type=str,
                    help='Additional config of per-layer quantization')

# mobilenet params
parser.add_argument('--mobilenet_width', default=1.0, type=float,
                    help='Mobilenet Width Muliplier')
parser.add_argument('--mobilenet_input', default=128, type=int,
                    help='Mobilenet input resolution ')

# mixed-precision params
parser.add_argument('--mem_constraint', default='', type=str,
                    help='Memory constraints for automatic bitwidth quantization')
parser.add_argument('--mixed_prec_quant', default='MixPL', type=str,
                    help='Type of quantization for mixed-precision low bitwidth: MixPL | MixPC')


def main():
    """Regression test: QuantizedDeployable vs IntegerDeployable MobileNet.

    Loads a pretrained FakeQuantized MobileNet checkpoint, moves the network
    through NEMO's QD stage and then ID stage, captures the per-layer
    activations of both stages via hooks, and asserts that their per-layer
    differences stay within golden tolerances stored in mobi_qd_id_res.pth.
    Runs entirely on CPU (GitHub CI has no GPUs).
    """
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()

    weight_bits = int(args.weight_bits)
    activ_bits = int(args.activ_bits)

    print("run arguments: %s" % args)

    # force CPU execution regardless of the --gpus flag
    args.gpus = None

    # create model
    # NOTE(review): print is given two positional arguments here (C-style),
    # so the model name is printed after the format string rather than being
    # substituted into it — presumably a leftover from logging.* usage.
    print("creating model %s", args.model)
    model_config = {'input_size': args.input_size, 'dataset': args.dataset, 'num_classes': 1000, \
                    'width_mult': float(args.mobilenet_width), 'input_dim': float(args.mobilenet_input) }

    # model_config = dict(model_config, **literal_eval(args.model_config))

    model = mobilenet(**model_config).to('cpu')
    print("created model with configuration: %s" % model_config)
    print(model)

    mobilenet_width = float(args.mobilenet_width)
    mobilenet_input = int(args.mobilenet_input)

    # transform the model in a NEMO FakeQuantized representation
    model = nemo.transform.quantize_pact(model, dummy_input=torch.randn((1,3,mobilenet_input,mobilenet_input)).to('cpu'), requantization_factor=128)

    # load the pretrained FakeQuantized weights and their per-layer precision
    checkpoint_file = args.resume
    if os.path.isfile(checkpoint_file):
        print("loading checkpoint '%s'" % args.resume)
        checkpoint_loaded = torch.load(checkpoint_file, map_location=torch.device('cpu'))
        checkpoint = checkpoint_loaded['state_dict']
        model.load_state_dict(checkpoint, strict=True)
        prec_dict = checkpoint_loaded.get('precision')
    else:
        print("no checkpoint found at '%s'" % args.resume)
        import sys; sys.exit(1)

    print("[NEMO] Not calibrating model, as it is pretrained")
    model.change_precision(bits=1, min_prec_dict=prec_dict)

    # snap the saved reference inputs onto the eps_in = 2/255 quantization grid
    inputs = torch.floor(torch.load("input_fq.pth", map_location=torch.device('cpu'))['in'] / (2./255)) * (2./255)
    inputs = inputs[:8] # reduce input size for GitHub CI regression test

    # reference activations in the FakeQuantized stage (captured via hooks)
    bin_fq, bout_fq, _ = nemo.utils.get_intermediate_activations(model, forward, model, inputs)

    # shift inputs to be non-negative; the same bias is folded into the first
    # conv/BN pair and removed again downstream of it
    input_bias = math.ceil(1.0 / (2./255)) * (2./255)
    input_bias_dict = {'model.0.0' : input_bias, 'model.0.1' : input_bias}
    remove_bias_dict = {'model.0.1' : 'model.0.2'}
    inputs += input_bias

    # move to the QuantizedDeployable stage with 20-bit intermediate precision
    model.qd_stage(eps_in=2./255, add_input_bias_dict=input_bias_dict, remove_bias_dict=remove_bias_dict, precision=nemo.precision.Precision(bits=20), int_accurate=True, limit_at_32_bits=False, postpone_bn_hardening=False)
    # fix ConstantPad2d
    model.model[0][0].value = input_bias

    bin_qd, bout_qd, _ = nemo.utils.get_intermediate_activations(model, forward, model, inputs, input_bias=input_bias)
    qds = copy.deepcopy(model.state_dict())

    # move to the IntegerDeployable stage
    model.id_stage(requantization_factor=128, limit_at_32_bits=False)
    # fix ConstantPad2d (255/2 == 1/eps_in: pad value moves to the integer domain)
    model.model[0][0].value = input_bias * (255./2)

    inputs = inputs * (255./2)
    ids = model.state_dict()
    bin_id, bout_id, _ = nemo.utils.get_intermediate_activations(model, forward, model, inputs, input_bias=input_bias, eps_in=2./255)

    # compare QD vs ID activations at the ReLU6 output (index 3) of each of
    # the 26 MobileNet blocks; ID outputs are integers, so rescale them by eps
    diff = collections.OrderedDict()
    if SAVE_RESULTS:
        results = {
            'mean_eps' : {},
            'max_eps' : {},
            'ratio' : {}
        }
    else:
        results = torch.load("mobi_qd_id_res.pth")
    for i in range(0,26):
        for j in range(3,4):
            k = 'model.%d.%d' % (i,j)
            # kn is the layer *following* k: its input quantum is k's output quantum
            kn = 'model.%d.%d' % (i if j<3 else i+1, j+1 if j<3 else 0)
            eps = model.get_eps_at(kn, eps_in=2./255)[0]
            diff[k] = (bout_id[k]*eps - bout_qd[k]).to('cpu').abs()
            print("%s:" % k)
            idx = diff[k]>=eps          # elements differing by at least one quantum
            n = idx.sum()               # number of differing elements
            t = (diff[k]>-1e9).sum()    # total element count (always-true mask)
            max_eps = torch.ceil(diff[k].max() / eps).item()
            mean_eps = torch.ceil(diff[k][idx].mean() / eps).item()
            # limits derived from the golden results (0 placeholders when regenerating)
            lim_max_eps = 0 if SAVE_RESULTS else math.ceil(results['max_eps'][k] * TOL_RESULTS)
            lim_mean_eps = 0 if SAVE_RESULTS else math.ceil(results['mean_eps'][k] * TOL_RESULTS)
            lim_ratio = 0 if SAVE_RESULTS else results['ratio'][k] * TOL_RESULTS
            try:
                print(" max: %.3f (%d eps): lim %d" % (diff[k].max().item(), max_eps, lim_max_eps))
                print(" mean: %.3f (%d eps) (only diff. elements): lim %d" % (diff[k][idx].mean().item(), mean_eps, lim_mean_eps))
                print(" #diff: %d/%d (%.1f%%): lim %.1f%%" % (n, t, float(n)/float(t)*100, lim_ratio))
            except ValueError:
                # no differing element: mean over an empty selection is NaN and
                # "%d" % NaN raises ValueError — treat as zero difference
                mean_eps = 0.0
                max_eps = 0.0
                print(" #diff: 0/%d (0%%): lim %.3f" % (t, lim_ratio))
            if SAVE_RESULTS:
                results['mean_eps'][k] = mean_eps
                results['max_eps'][k] = max_eps
                results['ratio'][k] = float(n)/float(t)*100
            # regression assertions against the (possibly just-written) golden values
            assert(mean_eps <= math.ceil(results['mean_eps'][k] * TOL_RESULTS))
            assert(max_eps <= math.ceil(results['max_eps'][k] * TOL_RESULTS))
            assert(float(n)/float(t)*100 <= results['ratio'][k] * TOL_PERCENT)
    if SAVE_RESULTS:
        torch.save(results, "mobi_qd_id_res.pth")

def forward(model, inputs, input_bias=0.0, eps_in=None, integer=False):
    """Plain inference pass used by get_intermediate_activations.

    Unlike the mobi_fq_qd.py variant, inputs are fed to the model unmodified:
    bias addition and eps scaling are applied by the caller before invoking
    this. `input_bias`, `eps_in` and `integer` are accepted only for
    interface compatibility and are ignored here.
    """
    model.eval()

    # measure data loading time
    with torch.no_grad():
        input_var = inputs

        # compute output
        output = model(input_var)

if __name__ == '__main__':
    main()

#
# mobilenet.py
# Manuele Rusci
#
# Copyright (C) 2019 University of Bologna
# All rights reserved.
#
# This is an implementation of the quantized mobilenet built from
# https://github.com/marvis/pytorch-mobilenet/blob/master/main.py
#

import PIL
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math
import torchvision.transforms as transforms
import torch.nn.functional as F

import nemo


###### Full Precision Blocks #############
def conv_dw(inp, oup, stride, pad1=0, bias_ena=False):
    """3x3 depthwise convolution block: pad -> dw-conv -> BN -> ReLU6.

    Asymmetric (0,1,0,1) padding is used for stride-2 so the output
    resolution is exactly halved. `oup` and `pad1` are unused and kept only
    for interface compatibility with existing callers.
    """
    padding = (1,1,1,1) if stride==1 else (0,1,0,1)
    return nn.Sequential(
        nn.ConstantPad2d(padding, value=0.),
        nn.Conv2d(inp, inp, 3, stride, 0, groups=inp, bias=bias_ena),
        nn.BatchNorm2d(inp),
        nn.ReLU6(inplace=False)
    )

def conv_pw(inp, oup, stride, bias_ena=False):
    """1x1 pointwise convolution block: pad(0) -> conv -> BN -> ReLU6.

    `stride` is unused: in MobileNet the stride is applied by the preceding
    depthwise convolution and pointwise convolutions always run at stride 1.
    The zero-width ConstantPad2d keeps the block layout identical to conv_dw
    (same sub-module indices, which NEMO's layer naming relies on).
    """
    padding = (0,0,0,0)
    return nn.Sequential(
        nn.ConstantPad2d(padding, value=0.),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=bias_ena),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=False)
    )

def conv_bn(inp, oup, stride):
    """Full 3x3 convolution stem block: pad -> conv -> BN -> ReLU6."""
    padding = (1,1,1,1) if stride==1 else (0,1,0,1)
    return nn.Sequential(
        nn.ConstantPad2d(padding, value=0.),
        nn.Conv2d(inp, oup, 3, stride, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=False)
    )

class mobilenet_real(nn.Module):
    """Full-precision MobileNet v1 with dw/pw blocks listed as separate
    Sequential entries (so NEMO can address each one as model.<i>.<j>).

    :param width_mult: channel width multiplier.
    :param input_dim: input resolution; must be one of 224, 192, 160, 128.
    """

    def __init__(self, width_mult=1.0, input_dim = 224):
        super(mobilenet_real, self).__init__()
        print(width_mult, input_dim)

        # pick the final average-pooling window and eval crop size from the
        # input resolution (input_dim may arrive as a float, e.g. 128.0)
        if input_dim == 224:
            avg_size = 7
            crop_size = 256
        elif input_dim == 192:
            avg_size = 6
            crop_size = 220
        elif input_dim == 160:
            avg_size = 5
            crop_size = 180
        elif input_dim == 128:
            avg_size = 4
            crop_size = 146
        else:
            # BUGFIX: the original `return -1` is invalid inside __init__ and
            # would raise "TypeError: __init__() should return None"; raise a
            # descriptive error for unsupported resolutions instead.
            raise ValueError("unsupported input_dim %r (expected 224, 192, 160 or 128)" % (input_dim,))
        self.width_mult = width_mult
        # each pair of lines is one depthwise-separable block: the depthwise
        # conv carries the stride, the following pointwise conv expands channels
        self.model = nn.Sequential(
            conv_bn(  3, int(width_mult* 32), 2),
            conv_dw( int(width_mult* 32), int(width_mult* 64), 1),
            conv_pw( int(width_mult* 32), int(width_mult* 64), 1),
            conv_dw( int(width_mult* 64), int(width_mult*128), 2),
            conv_pw( int(width_mult* 64), int(width_mult*128), 2),
            conv_dw( int(width_mult*128), int(width_mult*128), 1),
            conv_pw( int(width_mult*128), int(width_mult*128), 1),
            conv_dw( int(width_mult*128), int(width_mult*256), 2),
            conv_pw( int(width_mult*128), int(width_mult*256), 2),
            conv_dw( int(width_mult*256), int(width_mult*256), 1),
            conv_pw( int(width_mult*256), int(width_mult*256), 1),
            conv_dw( int(width_mult*256), int(width_mult*512), 2),
            conv_pw( int(width_mult*256), int(width_mult*512), 2),
            conv_dw( int(width_mult*512), int(width_mult*512), 1),
            conv_pw( int(width_mult*512), int(width_mult*512), 1),
            conv_dw( int(width_mult*512), int(width_mult*512), 1),
            conv_pw( int(width_mult*512), int(width_mult*512), 1),
            conv_dw( int(width_mult*512), int(width_mult*512), 1),
            conv_pw( int(width_mult*512), int(width_mult*512), 1),
            conv_dw( int(width_mult*512), int(width_mult*512), 1),
            conv_pw( int(width_mult*512), int(width_mult*512), 1),
            conv_dw( int(width_mult*512), int(width_mult*512), 1),
            conv_pw( int(width_mult*512), int(width_mult*512), 1),
            conv_dw( int(width_mult*512), int(width_mult*1024), 2),
            conv_pw( int(width_mult*512), int(width_mult*1024), 2),
            conv_dw( int(width_mult*1024), int(width_mult*1024), 1),
            conv_pw( int(width_mult*1024), int(width_mult*1024), 1),
            nn.AvgPool2d(avg_size),
        )
        self.fc = nn.Linear( int(width_mult*1024), 1000)

        # learning-rate schedule consumed by the (external) training loop
        self.regime = {
            0: {'optimizer': 'Adam', 'lr': 1e-4 },
            5: {'lr': 5e-5},
            8: {'lr': 1e-5 }
        }

        #prepocessing
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        # NOTE(review): transforms.Scale is deprecated in favor of Resize in
        # newer torchvision; kept because the CI requirements pin torchvision
        # versions where Scale is still available.
        self.input_transform = {
            'train': transforms.Compose([
                transforms.Scale(crop_size), #, interpolation=PIL.Image.BILINEAR),
                transforms.RandomCrop(input_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize
            ]),
            'eval': transforms.Compose([
                transforms.Scale(crop_size), #, interpolation=PIL.Image.BILINEAR),
                transforms.CenterCrop(input_dim),
                transforms.ToTensor(),
                normalize
            ])
        }

    def forward(self, x):
        """Backbone -> flatten -> classifier head."""
        x = self.model(x)
        x = x.flatten(1)
        x = self.fc(x)
        return x


def mobilenet(activ_bits =None, weight_bits= None, width_mult=1.0, input_dim = 224,**kwargs):
    """Factory returning a full-precision mobilenet_real.

    `activ_bits`, `weight_bits` and any extra kwargs (input_size, dataset,
    num_classes, ...) are accepted for configuration-dict compatibility but
    only printed; quantization is applied later by NEMO.
    """
    print(','.join('{0}={1!r}'.format(k,v) for k,v in kwargs.items()))

    print(activ_bits, weight_bits)

    model = mobilenet_real(width_mult, input_dim)

    return model