├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── figs └── logo_torchshiftadd.png ├── setup.py ├── test ├── layers │ ├── test_adder.py │ └── test_shift.py └── models │ ├── test_resnet20_adder.py │ ├── test_resnet20_shift.py │ └── test_resnet20_shiftadd.py └── torchshiftadd ├── layers ├── __init__.py ├── adder.py ├── attention.py ├── extension │ ├── adder_cuda.cpp │ ├── adder_cuda_kernel.cu │ └── adder_cuda_kernel_torch_1.4.cu └── shift.py ├── models ├── __init__.py ├── resnet20.py ├── resnet20_adder.py ├── resnet20_shift.py └── resnet20_shiftadd.py └── utils ├── __init__.py ├── ckpt_loading.py ├── comm.py ├── decorator.py ├── quantize.py ├── ste.py ├── test_acc.py └── torch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | ckpts/ 11 | data.cifar*/ 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | Coming soon. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
<div align="center">
2 | <img src="figs/logo_torchshiftadd.png" alt="torchshiftadd logo">
3 | </div>
4 |
5 | <div align="center">
6 | A PyTorch library for developing energy efficient multiplication-less models.
7 | </div>
8 |
9 | [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-green)](https://opensource.org/licenses/Apache-2.0)
10 | [![Contributions](https://img.shields.io/badge/contributions-welcome-blue)](https://github.com/GATECH-EIC/torchshiftadd/blob/master/CONTRIBUTING.md)
11 |
12 | # TorchShiftAdd Overview
13 |
14 | Welcome to TorchShiftAdd, your go-to open-source library for crafting energy-efficient multiplication-less models and applications!
15 |
16 | [TorchShiftAdd](https://github.com/GATECH-EIC/torchshiftadd) embodies a pioneering initiative to simplify and expand the realm of multiplication-less networks within the machine learning community. Key features include:
17 |
18 | * Ready-to-use implementations of a wide range of ShiftAdd-based multiplication-less CNNs and Transformers.
19 | * CUDA kernels and TVM compilation support for seamless GPU deployment.
20 | * Profiling tools to furnish FLOPs, energy, and latency breakdown data for in-depth analysis and optimization.
21 | * Hardware accelerator simulators to estimate energy savings and latency improvements on ASICs or FPGAs.
22 | * Flexible support for developing both algorithmic and hardware accelerator designs tailored for multiplication-less networks.
23 |
24 |
25 |
26 | ## List of Implemented Papers
27 | * **ShiftAdd-based Convolutional Neural Networks**
28 |   + [[NeurIPS'20] ShiftAddNet: A Hardware-Inspired Deep Network](https://arxiv.org/abs/2010.12785)
29 |   + [[CVPR'20 Oral] AdderNet: Do We Really Need Multiplications in Deep Learning?](https://arxiv.org/abs/1912.13200)
30 |   + [[CVPR'21 Workshop] DeepShift: Towards Multiplication-Less Neural Networks](https://arxiv.org/abs/1905.13298)
31 | * **ShiftAdd-based Transformers**
32 |   + [[NeurIPS'23] ShiftAddViT: Mixture of Multiplication Primitives Towards Efficient Vision Transformer](https://arxiv.org/abs/2306.06446)
33 |   + [[NeurIPS'24] ShiftAddLLM: Accelerating Pretrained LLMs via Post-Training Multiplication-Less Reparameterization](https://arxiv.org/abs/2406.05981)
34 | * **Hardware Accelerators for ShiftAdd-based Multiplication-less Networks**
35 |   + [[ICCAD'22] NASA: Neural Architecture Search and Acceleration for Hardware Inspired Hybrid Networks](https://arxiv.org/abs/2210.13361)
36 |   + [[IEEE TCAS-I] NASA+: Neural Architecture Search and Acceleration for Multiplication-Reduced Hybrid Networks](https://ieeexplore.ieee.org/document/10078392)
37 |
38 | # Installation
39 |
40 | ````bash
41 | pip install -e .
42 | ````
43 |
44 | # Quick Start
45 |
46 | The codebase currently supports ShiftAdd-based CNNs. To try them out, run our test files below, or see the minimal usage sketch at the end of this README:
47 |
48 | ````bash
49 | python test/models/test_resnet20_adder.py
50 | python test/models/test_resnet20_shift.py
51 | python test/models/test_resnet20_shiftadd.py
52 | ````
53 |
54 | # Upcoming Features
55 |
56 | We will continuously develop this toolbox:
57 |
58 | - [x] ShiftAdd-based Convolutional Neural Networks
59 | - [ ] ShiftAdd-based Transformers
60 | - [ ] Hardware Accelerators for Energy & Latency Estimation
61 |
62 | # Contributing
63 |
64 | TorchShiftAdd is released under the [Apache-2.0 License](LICENSE). Everyone is welcome to contribute to the development of TorchShiftAdd. Please refer to the [contributing guidelines](CONTRIBUTING.md) for more details.
65 |
66 | # Acknowledgement
67 |
68 | We thank all co-authors of ShiftAddNet, ShiftAddNAS, ShiftAddViT, and ShiftAddLLM.
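# Minimal Usage Sketch

The snippet below is a minimal usage sketch distilled from the shape tests under `test/`; all names come from `torchshiftadd.layers` and `torchshiftadd.models` exactly as those tests exercise them. It assumes the package is installed as above and that a CUDA toolchain is available, since the adder layer builds a CUDA extension from `torchshiftadd/layers/extension/` at import time; the tensor shapes follow the CIFAR-10 tests and are illustrative only.

````python
import torch

from torchshiftadd import layers, models

# A standalone adder layer: a 3x3 "convolution" that accumulates negated
# absolute differences instead of multiply-accumulates. Note the singular
# keyword names (input_channel / output_channel), unlike torch.nn.Conv2d.
adder = layers.Adder2D(
    input_channel=3,
    output_channel=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
)
x = torch.rand(1, 3, 32, 32)
y = adder(x)                                    # shape: (1, 64, 32, 32)

# Full multiplication-less ResNet-20 variants for CIFAR-10.
adder_net = models.resnet20_adder()
shiftadd_net = models.resnet20_shiftadd()
logits = adder_net(torch.rand(2, 3, 32, 32))    # shape: (2, 10)
````

The pretrained CIFAR-10 checkpoints used by the accuracy tests are fetched from the `hryou1998/pretrained-ckpts` repository on the Hugging Face Hub via `hf_hub_download`; see the scripts under `test/models/` for the exact loading code.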
-------------------------------------------------------------------------------- /figs/logo_torchshiftadd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/torchshiftadd/f48be2983b0a0f3a299eca449e88c6fb8e5264fb/figs/logo_torchshiftadd.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="torchshiftadd", 8 | version="0.0.1", 9 | author="TorchShiftAdd Team", 10 | author_email="haoran.you@gatech.edu", 11 | description="A PyTorch library for developing energy efficient multiplication-less models", 12 | license="Apache License 2.0", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/GATECH-EIC/torchshiftadd", 16 | packages=setuptools.find_packages(), 17 | package_data={ 18 | "torchshiftadd": [ 19 | "layers/extension/adder_cuda.cpp", 20 | "layers/extension/adder_cuda_kernel.cu", 21 | ] 22 | }, 23 | install_requires=[ 24 | "torch>=1.7.0", 25 | "torchvision", 26 | "numpy>=1.19.0", 27 | "scipy>=1.5.0", 28 | "scikit-learn>=0.23.0", 29 | "matplotlib>=3.2.0", 30 | "tqdm>=4.46.0", 31 | "ninja", 32 | ], 33 | python_requires=">=3.6,<3.13", 34 | classifiers=[ 35 | "Programming Language :: Python :: 3", 36 | "Operating System :: OS Independent", 37 | ], 38 | ) -------------------------------------------------------------------------------- /test/layers/test_adder.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from torchshiftadd import layers 5 | 6 | class Adder2DTest(unittest.TestCase): 7 | 8 | def setup(self): 9 | self.input = torch.rand(1, 3, 32, 32) 10 | self.weight = torch.rand(64, 3, 3, 3) 11 | self.bias = torch.rand(64) 12 | self.stride = 1 13 | self.padding = 1 14 | self.groups = 1 15 | self.eta = 1.0 16 | 17 | def test_adder2d(self): 18 | self.setup() 19 | adder = layers.Adder2D( 20 | input_channel=3, 21 | output_channel=64, 22 | kernel_size=3, 23 | stride=self.stride, 24 | padding=self.padding, 25 | groups=self.groups, 26 | bias=True, 27 | eta=self.eta, 28 | ) 29 | output = adder(self.input) 30 | self.assertEqual(output.shape, (1, 64, 32, 32)) 31 | 32 | if __name__ == "__main__": 33 | unittest.main() -------------------------------------------------------------------------------- /test/layers/test_shift.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | 4 | from torchshiftadd import layers 5 | 6 | class LinearShiftTest(unittest.TestCase): 7 | 8 | def setup(self): 9 | self.input = torch.rand(32, 32) 10 | 11 | def test_adder2d(self): 12 | self.setup() 13 | shift = layers.LinearShift( 14 | in_features=32, 15 | out_features=64, 16 | bias=True, 17 | ) 18 | output = shift(self.input) 19 | self.assertEqual(output.shape, (32, 64)) 20 | 21 | class Conv2dShiftTest(unittest.TestCase): 22 | 23 | def setup(self): 24 | self.input = torch.rand(1, 3, 32, 32) 25 | self.weight = torch.rand(64, 3, 3, 3) 26 | self.bias = torch.rand(64) 27 | self.stride = 1 28 | self.padding = 1 29 | self.groups = 1 30 | 31 | def test_adder2d(self): 32 | self.setup() 33 | shift = layers.Conv2dShift( 34 | in_channels=3, 35 | out_channels=64, 36 | kernel_size=3, 37 | 
stride=self.stride, 38 | padding=self.padding, 39 | groups=self.groups, 40 | bias=True, 41 | weight_bits=4, 42 | input_bits=16, 43 | ) 44 | output = shift(self.input) 45 | self.assertEqual(output.shape, (1, 64, 32, 32)) 46 | 47 | if __name__ == "__main__": 48 | unittest.main() -------------------------------------------------------------------------------- /test/models/test_resnet20_adder.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | from huggingface_hub import hf_hub_download 6 | from torchshiftadd import layers, models 7 | from torchshiftadd.utils import load_add_state_dict, test_acc 8 | 9 | class ResNetAdderShapeTest(unittest.TestCase): 10 | 11 | def setup(self): 12 | self.input = torch.rand(2, 3, 32, 32) 13 | self.model = models.resnet20_adder() 14 | 15 | def test_resnet_adder(self): 16 | self.setup() 17 | output = self.model(self.input) 18 | self.assertEqual(output.shape, (2, 10)) 19 | 20 | class ResNetAdderAccTest(unittest.TestCase): 21 | 22 | def setup(self): 23 | self.model = models.resnet20_adder().cuda() 24 | self.test_dataloader = torch.utils.data.DataLoader( 25 | datasets.CIFAR10( 26 | './data.cifar10', 27 | train=False, 28 | download=True, 29 | transform=transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 32 | ]) 33 | ), 34 | batch_size=64, 35 | shuffle=True, 36 | ) 37 | 38 | def test_resnet_adder(self): 39 | self.setup() 40 | 41 | ckpt_path = hf_hub_download( 42 | repo_id="hryou1998/pretrained-ckpts", 43 | filename="resnet20-adder-cifar10-FP32.pth.tar", 44 | ) 45 | 46 | checkpoint = torch.load(ckpt_path, map_location='cpu') 47 | self.model.load_state_dict(load_add_state_dict(checkpoint['state_dict'])) 48 | 49 | top_1, top_5 = test_acc(self.model, self.test_dataloader) 50 | 51 | print("Top-1 Acc: {:.2f}%".format(top_1)) 52 | print("Top-5 Acc: {:.2f}%".format(top_5)) 53 | 54 | if __name__ == "__main__": 55 | unittest.main() -------------------------------------------------------------------------------- /test/models/test_resnet20_shift.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | from huggingface_hub import hf_hub_download 6 | from torchshiftadd import layers, models 7 | from torchshiftadd.utils import test_acc 8 | 9 | class ResNetShiftShapeTest(unittest.TestCase): 10 | 11 | def setup(self): 12 | self.input = torch.rand(2, 3, 32, 32) 13 | self.model = models.resnet20() 14 | models.convert_to_shift(self.model) 15 | 16 | def test_resnet_shift(self): 17 | self.setup() 18 | output = self.model(self.input) 19 | self.assertEqual(output.shape, (2, 10)) 20 | 21 | class ResNetShiftAccTest(unittest.TestCase): 22 | 23 | def setup(self): 24 | self.model = models.resnet20() 25 | models.convert_to_shift(self.model) 26 | self.model.cuda() 27 | self.test_dataloader = torch.utils.data.DataLoader( 28 | datasets.CIFAR10( 29 | './data.cifar10', 30 | train=False, 31 | download=True, 32 | transform=transforms.Compose([ 33 | transforms.ToTensor(), 34 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 35 | ]) 36 | ), 37 | batch_size=64, 38 | shuffle=True, 39 | ) 40 | 41 | def test_resnet_shift(self): 42 | self.setup() 43 | 44 | ckpt_path = hf_hub_download( 45 | 
repo_id="hryou1998/pretrained-ckpts", 46 | filename="resnet20-shift-cifar10.pth.tar", 47 | ) 48 | 49 | checkpoint = torch.load(ckpt_path, map_location='cpu') 50 | self.model.load_state_dict(checkpoint['state_dict']) 51 | models.convert_to_shift(self.model) 52 | 53 | top_1, top_5 = test_acc(self.model, self.test_dataloader) 54 | 55 | print("Top-1 Acc: {:.2f}%".format(top_1)) 56 | print("Top-5 Acc: {:.2f}%".format(top_5)) 57 | 58 | if __name__ == "__main__": 59 | unittest.main() -------------------------------------------------------------------------------- /test/models/test_resnet20_shiftadd.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | from huggingface_hub import hf_hub_download 6 | from torchshiftadd import layers, models 7 | from torchshiftadd.utils import load_shiftadd_state_dict, test_acc 8 | 9 | class ResNetShiftAddShapeTest(unittest.TestCase): 10 | 11 | def setup(self): 12 | self.input = torch.rand(2, 3, 32, 32) 13 | self.model = models.resnet20_shiftadd() 14 | 15 | def test_resnet_shiftadd(self): 16 | self.setup() 17 | output = self.model(self.input) 18 | self.assertEqual(output.shape, (2, 10)) 19 | 20 | class ResNetShiftAddAccTest(unittest.TestCase): 21 | 22 | def setup(self): 23 | self.model = models.resnet20_shiftadd().cuda() 24 | self.test_dataloader = torch.utils.data.DataLoader( 25 | datasets.CIFAR10( 26 | './data.cifar10', 27 | train=False, 28 | download=True, 29 | transform=transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 32 | ]) 33 | ), 34 | batch_size=64, 35 | shuffle=True, 36 | ) 37 | 38 | def test_resnet_shiftadd(self): 39 | self.setup() 40 | 41 | ckpt_path = hf_hub_download( 42 | repo_id="hryou1998/pretrained-ckpts", 43 | filename="resnet20-shiftadd-cifar10-FIX16.pth.tar", 44 | ) 45 | 46 | checkpoint = torch.load(ckpt_path, map_location='cpu') 47 | self.model.load_state_dict(load_shiftadd_state_dict(checkpoint['state_dict']), strict=False) 48 | 49 | top_1, top_5 = test_acc(self.model, self.test_dataloader) 50 | 51 | print("Top-1 Acc: {:.2f}%".format(top_1)) 52 | print("Top-5 Acc: {:.2f}%".format(top_5)) 53 | 54 | if __name__ == "__main__": 55 | unittest.main() -------------------------------------------------------------------------------- /torchshiftadd/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .adder import Adder2D 2 | from .shift import LinearShift, Conv2dShift -------------------------------------------------------------------------------- /torchshiftadd/layers/adder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from torchshiftadd import utils 8 | 9 | path = os.path.join(os.path.dirname(__file__), "extension") 10 | adder_cuda = utils.load_extension( 11 | "adder_cuda", [ 12 | os.path.join(path, "adder_cuda.cpp"), 13 | os.path.join(path, "adder_cuda_kernel.cu"), 14 | ] 15 | ) 16 | 17 | class Adder2D(nn.Module): 18 | 19 | def __init__(self, 20 | input_channel, 21 | output_channel, 22 | kernel_size, 23 | stride = 1, 24 | padding = 0, 25 | groups = 1, 26 | bias = False, 27 | eta = 0.2, 28 | ): 29 | super(Adder2D, self).__init__() 30 | self.stride = stride 31 | self.padding = padding 32 | self.groups = groups 33 | 
self.input_channel = input_channel 34 | self.output_channel = output_channel 35 | self.kernel_size = kernel_size 36 | self.bias = bias 37 | self.eta = eta 38 | 39 | self.adder = torch.nn.Parameter( 40 | nn.init.normal_( 41 | torch.randn(output_channel, input_channel // groups, kernel_size, kernel_size) 42 | ) 43 | ) 44 | 45 | if self.bias: 46 | self.bias = torch.nn.Parameter( 47 | nn.init.uniform_(torch.zeros(output_channel)) 48 | ) 49 | else: 50 | self.bias = None 51 | 52 | def forward(self, input, ratio_out=1, ratio_in=1, ratio_g=1, kernel=None): 53 | 54 | sample_weight = self.adder[:(self.output_channel//ratio_out),:(self.input_channel//ratio_in),:,:] 55 | if (kernel!=None): 56 | start, end = sub_filter_start_end(5, kernel) 57 | sample_weight = sample_weight[:,:, start:end, start:end] 58 | padding = kernel//2 59 | else: 60 | padding = self.padding 61 | 62 | output = Adder2DFunction.apply( 63 | input, 64 | sample_weight, 65 | self.kernel_size, 66 | self.stride, 67 | padding, 68 | (self.groups//ratio_g), 69 | self.eta, 70 | ) 71 | if self.bias is not None: 72 | output += self.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 73 | 74 | return output 75 | 76 | 77 | class Adder2DFunction(torch.autograd.Function): 78 | @staticmethod 79 | def forward(ctx, input, weight, kernel_size, stride, padding, groups, eta): 80 | ctx.save_for_backward(input, weight) 81 | ctx.kernel_size = kernel_size 82 | ctx.stride = stride 83 | ctx.padding = padding 84 | ctx.groups = groups 85 | ctx.eta = eta 86 | ctx.quantize = False 87 | 88 | output = input.new_zeros( 89 | get_conv2d_output_shape(input, weight, stride, padding) 90 | ) 91 | 92 | adder_cuda.forward( 93 | input, 94 | weight, 95 | output, 96 | kernel_size, kernel_size, 97 | stride, stride, 98 | padding, padding, 99 | groups, groups 100 | ) 101 | 102 | return output 103 | 104 | @staticmethod 105 | def backward(ctx, grad_output): 106 | input, weight = ctx.saved_tensors 107 | grad_input = grad_weight = None 108 | eta, kernel_size, stride, padding, groups = ( 109 | ctx.eta, ctx.kernel_size, ctx.stride, ctx.padding, ctx.groups 110 | ) 111 | 112 | # input 113 | if ctx.needs_input_grad[0]: 114 | grad_input = torch.zeros_like(input) 115 | adder_cuda.backward_input( 116 | grad_output, 117 | input, 118 | weight, 119 | grad_input, 120 | kernel_size, kernel_size, 121 | stride, stride, 122 | padding, padding, 123 | groups, groups 124 | ) 125 | 126 | # weight 127 | if ctx.needs_input_grad[1]: 128 | grad_weight = torch.zeros_like(weight) 129 | adder_cuda.backward_weight( 130 | grad_output, 131 | input, 132 | weight, 133 | grad_weight, 134 | kernel_size, kernel_size, 135 | stride, stride, 136 | padding, padding, 137 | groups, groups) 138 | grad_weight = eta * np.sqrt(grad_weight.numel()) / torch.norm(grad_weight).clamp(min=1e-12) * grad_weight 139 | 140 | return grad_input, grad_weight, None, None, None, None, None 141 | 142 | 143 | def get_conv2d_output_shape(input, weight, stride, padding): 144 | n_filters, d_filter, h_filter, w_filter = weight.size() 145 | n_x, d_x, h_x, w_x = input.size() 146 | 147 | h_out = (h_x - h_filter + 2 * padding) // stride + 1 148 | w_out = (w_x - w_filter + 2 * padding) // stride + 1 149 | 150 | return (n_x, n_filters, h_out, w_out) 151 | 152 | 153 | def sub_filter_start_end(kernel_size, sub_kernel_size): 154 | center = kernel_size // 2 155 | dev = sub_kernel_size // 2 156 | start, end = center - dev, center + dev + 1 157 | assert end - start == sub_kernel_size 158 | return start, end 
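For reference, the sketch below reproduces what `Adder2DFunction` / `adder_cuda.forward` compute in the `groups == 1` case using plain PyTorch ops. It is a checking aid only, not part of the package, and the helper name `adder2d_reference` is ours. Each output element is the negated L1 distance between an input patch and a filter, and zero-padded positions contribute `|0 - w|`, matching the padding branch of the CUDA kernel further below.

```python
import torch
import torch.nn.functional as F


def adder2d_reference(x, weight, stride=1, padding=0):
    # x: (N, C, H, W), weight: (M, C, KH, KW); groups == 1 only.
    n = x.size(0)
    m, _, kh, kw = weight.shape
    h_out = (x.size(2) + 2 * padding - kh) // stride + 1
    w_out = (x.size(3) + 2 * padding - kw) // stride + 1
    # im2col with zero padding: padded entries enter the L1 distance as |0 - w|,
    # exactly like the "padded area" branch in adder_cuda_kernel.cu.
    cols = F.unfold(x, (kh, kw), padding=padding, stride=stride)  # (N, C*KH*KW, L)
    dists = torch.cdist(
        cols.transpose(1, 2),             # (N, L, C*KH*KW) -- one row per patch
        weight.view(m, -1).unsqueeze(0),  # (1, M, C*KH*KW) -- one row per filter
        p=1,                              # L1 distance
    )                                     # (N, L, M)
    return -dists.transpose(1, 2).reshape(n, m, h_out, w_out)


# Example check against the CUDA path (shapes as in test/layers/test_adder.py):
# x, w = torch.rand(1, 3, 32, 32).cuda(), torch.rand(64, 3, 3, 3).cuda()
# y_ref = adder2d_reference(x, w, stride=1, padding=1)   # (1, 64, 32, 32)
```

The CUDA extension remains the operational path: besides this forward pass it also implements the HardTanh-clipped gradient w.r.t. the input and the full-precision gradient w.r.t. the weights, which `Adder2DFunction.backward` rescales by `eta * sqrt(numel) / norm`.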
-------------------------------------------------------------------------------- /torchshiftadd/layers/attention.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/torchshiftadd/f48be2983b0a0f3a299eca449e88c6fb8e5264fb/torchshiftadd/layers/attention.py -------------------------------------------------------------------------------- /torchshiftadd/layers/extension/adder_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int adder_cuda_forward( 5 | const at::Tensor &input, 6 | const at::Tensor &weight, 7 | // const at::Tensor &bias, 8 | at::Tensor &output, 9 | int KW, int KH, 10 | int SW, int SH, 11 | int PW, int PH, 12 | int GW, int GH 13 | ); 14 | 15 | int adder_cuda_backward_grad_in( 16 | at::Tensor &grad_out, 17 | at::Tensor &input, 18 | at::Tensor &weight, 19 | at::Tensor &grad_in, 20 | int KW, int KH, 21 | int SW, int SH, 22 | int PW, int PH, 23 | int GW, int GH 24 | ); 25 | 26 | int adder_cuda_backward_grad_weight( 27 | at::Tensor &grad_out, 28 | at::Tensor &input, 29 | at::Tensor &weight, 30 | at::Tensor &grad_weight, 31 | int KW, int KH, 32 | int SW, int SH, 33 | int PW, int PH, 34 | int GW, int GH 35 | ); 36 | 37 | #define CHECK_CUDA(x) AT_ASSERT((x).type().is_cuda(), #x "must be a CUDA tensor") 38 | #define CHECK_CONTIGUOUS(x) AT_ASSERT((x).type().is_contiguous(), #x "must be contiguous") 39 | #define CHECK_INPUT(x) \ 40 | CHECK_CUDA((x)); \ 41 | CHECK_CONTIGUOUS((x)) 42 | 43 | int adder_forward( 44 | const at::Tensor &input, 45 | const at::Tensor &weight, 46 | // const at::Tensor &bias, 47 | at::Tensor &output, 48 | int KW, int KH, 49 | int SW, int SH, 50 | int PW, int PH, 51 | int GW, int GH 52 | ) 53 | { 54 | // TODO: add checks checks 55 | return adder_cuda_forward( 56 | input, 57 | weight, 58 | // bias, 59 | output, 60 | KW, KH, 61 | SW, SH, 62 | PW, PH, 63 | GW, GH 64 | ); 65 | } 66 | 67 | int adder_backward_input( 68 | at::Tensor &grad_out, 69 | at::Tensor &input, 70 | at::Tensor &weight, 71 | at::Tensor &grad_in, 72 | int KW, int KH, 73 | int SW, int SH, 74 | int PW, int PH, 75 | int GW, int GH 76 | ) 77 | { 78 | // TODO: add checks checks 79 | return adder_cuda_backward_grad_in( 80 | grad_out, 81 | input, 82 | weight, 83 | grad_in, 84 | KW, KH, 85 | SW, SH, 86 | PW, PH, 87 | GW, GH 88 | ); 89 | } 90 | 91 | int adder_backward_weight( 92 | at::Tensor &grad_out, 93 | at::Tensor &input, 94 | at::Tensor &weight, 95 | at::Tensor &grad_weight, 96 | int KW, int KH, 97 | int SW, int SH, 98 | int PW, int PH, 99 | int GW, int GH 100 | ) 101 | { 102 | // TODO: add checks checks 103 | return adder_cuda_backward_grad_weight( 104 | grad_out, 105 | input, 106 | weight, 107 | grad_weight, 108 | KW, KH, 109 | SW, SH, 110 | PW, PH, 111 | GW, GH 112 | ); 113 | } 114 | 115 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 116 | { 117 | m.def("forward", &adder_forward, "adder forward (CUDA)"); 118 | m.def("backward_input", &adder_backward_input, "adder backward input (CUDA)"); 119 | m.def("backward_weight", &adder_backward_weight, "adder backward weight (CUDA)"); 120 | } -------------------------------------------------------------------------------- /torchshiftadd/layers/extension/adder_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #define MAX_BLOCKS 256 10 | #define NUM_THREADS 256 11 | #define MAX(a, 
b) ((a) > (b)) ? (a) : (b) 12 | #define MIN(a, b) ((a) < (b)) ? (a) : (b) 13 | #define HARDTANH(x) ((x) < (-1.0)) ? (-1.0) : (((x) <= (1.0)) ? (x) : (1.0)) 14 | const int WARP_SIZE = 32; 15 | const int MAX_BLOCK_SIZE = 256; 16 | 17 | template 18 | struct SharedMem { 19 | __device__ T *getPointer() { 20 | __shared__ T smem[MAX_BLOCK_SIZE]; 21 | return smem; 22 | } 23 | }; 24 | 25 | static int getGradParamsNumThreads(int batchSize) { 26 | return std::min(batchSize * WARP_SIZE, MAX_BLOCK_SIZE); 27 | } 28 | 29 | int get_blocks(int n) { 30 | return MIN(MAX_BLOCKS, (n - NUM_THREADS + 1) / NUM_THREADS) + 1; 31 | } 32 | 33 | template 34 | __global__ void adder_forward_kernel( 35 | const scalar_t* __restrict__ input, 36 | const scalar_t* __restrict__ weight, 37 | scalar_t* __restrict__ output, 38 | const int num_elem, 39 | const int out_channels, 40 | const int in_channels, 41 | const int IW, const int IH, 42 | const int OW, const int OH, 43 | const int KW, const int KH, 44 | const int SW, const int SH, 45 | const int PW, const int PH, 46 | const int GW, const int GH) 47 | { 48 | // #TODO: 49 | if (GW==1) 50 | { 51 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 52 | { 53 | const int n = index / OW / OH / out_channels; 54 | const int m = index / OW / OH % out_channels; 55 | const int h = index / OW % OH; 56 | const int w = index % OW; 57 | 58 | const scalar_t *p_weight = weight + m * in_channels * KH * KW; //the start position of the kernel(corresponding to the output) 59 | // scalar_t value = bias[m]; 60 | scalar_t value = 0; 61 | // #TODO: 62 | // #pragma unroll 63 | for (int cc = 0; cc < in_channels; cc++) 64 | { 65 | // #pragma unroll 66 | const int image_offset0 = (n * in_channels + cc) * IH * IW; 67 | for (int kh = 0; kh < KH; kh++) 68 | { 69 | // #pragma unroll 70 | for (int kw = 0; kw < KW; kw++) 71 | { 72 | const int ih = h * SH - PH + kh; 73 | const int iw = w * SW - PW + kw; 74 | 75 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 76 | if (boundary_condition) 77 | { 78 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 79 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); 80 | } 81 | else // padded area 82 | { 83 | value -= abs(*p_weight); 84 | } 85 | p_weight++; 86 | } 87 | } 88 | } 89 | output[index] = value; 90 | } 91 | } 92 | if (GW==2) 93 | { 94 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 95 | // total size of output (batch_size * out_channels * W-out * H_out) 96 | { 97 | // TODO 98 | const int n = index / OW / OH / out_channels; // batch size of n 99 | const int m = index / OW / OH % out_channels; // relative output channel size 100 | const int h = index / OW % OH; // relative position of H 101 | const int w = index % OW; //relative position of W 102 | 103 | const scalar_t *p_weight = weight + m * in_channels/2 * KH * KW; //the start position of the kernel(corresponding to the output) 104 | 105 | // scalar_t value = bias[m]; 106 | scalar_t value = 0; 107 | // #TODO: 108 | // #pragma unroll 109 | if (m < out_channels/2) 110 | { 111 | for (int cc = 0; cc < in_channels/2; cc++) 112 | { 113 | // #pragma unroll 114 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute) 115 | for (int kh = 0; kh < KH; kh++) 116 | { 117 | // #pragma unroll 118 | for (int kw = 0; kw < KW; kw++) 119 | { 120 | const int ih = h * SH - PH + kh; // *stride-padding of H 121 | const int iw = w * 
SW - PW + kw; // *stride-padding of W 122 | 123 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 124 | if (boundary_condition) 125 | { 126 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 127 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation 128 | } 129 | else // padded area 130 | { 131 | value -= abs(*p_weight); 132 | } 133 | p_weight++; 134 | } 135 | } 136 | } 137 | } 138 | else 139 | { 140 | for (int cc = in_channels/2; cc < in_channels; cc++) 141 | { 142 | // #pragma unroll 143 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute) 144 | for (int kh = 0; kh < KH; kh++) 145 | { 146 | // #pragma unroll 147 | for (int kw = 0; kw < KW; kw++) 148 | { 149 | const int ih = h * SH - PH + kh; // *stride-padding of H 150 | const int iw = w * SW - PW + kw; // *stride-padding of W 151 | 152 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 153 | if (boundary_condition) 154 | { 155 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 156 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation 157 | } 158 | else // padded area 159 | { 160 | value -= abs(*p_weight); 161 | } 162 | p_weight++; 163 | } 164 | } 165 | } 166 | } 167 | output[index] = value; 168 | } 169 | } 170 | if (GW==in_channels) //Dpws 171 | { 172 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 173 | { 174 | const int n = index / OW / OH / out_channels; 175 | const int m = index / OW / OH % out_channels; 176 | const int h = index / OW % OH; 177 | const int w = index % OW; 178 | 179 | const scalar_t *p_weight = weight + m * 1 * KH * KW; 180 | // scalar_t value = bias[m]; 181 | scalar_t value = 0; 182 | // #TODO: 183 | // #pragma unroll 184 | // #pragma unroll 185 | const int image_offset0 = (n * in_channels + m) * IH * IW; 186 | for (int kh = 0; kh < KH; kh++) 187 | { 188 | // #pragma unroll 189 | for (int kw = 0; kw < KW; kw++) 190 | { 191 | const int ih = h * SH - PH + kh; 192 | const int iw = w * SW - PW + kw; 193 | 194 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 195 | if (boundary_condition) 196 | { 197 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 198 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); 199 | } 200 | else // padded area 201 | { 202 | value -= abs(*p_weight); 203 | } 204 | p_weight++; 205 | } 206 | } 207 | output[index] = value; 208 | } 209 | } 210 | 211 | } 212 | 213 | template 214 | __global__ void adder_backward_grad_in_kernel( 215 | scalar_t *grad_out, 216 | scalar_t *input, 217 | scalar_t *weight, 218 | scalar_t *grad_in, 219 | const int num_elem, 220 | const int out_channels, 221 | const int in_channels, 222 | const int IW, const int IH, 223 | const int OW, const int OH, 224 | const int KW, const int KH, 225 | const int SW, const int SH, 226 | const int PW, const int PH, 227 | const int GW, const int GH) 228 | { if (GW==1) 229 | { 230 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 231 | { 232 | const int n = index / IW / IH / in_channels; 233 | const int c = index / IW / IH % in_channels; 234 | const int h = index / IW % IH; 235 | const int w = index % IW; 236 | 237 | scalar_t value = 0; 238 | for (int mm = 0; mm < out_channels; mm++) 239 | { 240 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW; 241 | scalar_t *p_weight = weight 
+ (mm * in_channels + c) * KH * KW; 242 | for (int kh = 0; kh < KH; kh++) 243 | { 244 | for (int kw = 0; kw < KW; kw++) 245 | { 246 | int oh = h + PH - kh; 247 | int ow = w + PW - kw; 248 | 249 | if ((oh % SH == 0) && (ow % SW == 0)) 250 | { 251 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 252 | if (boundary_condition) 253 | { 254 | oh = oh / SH; 255 | ow = ow / SW; 256 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 257 | scalar_t ht = HARDTANH(*p_weight - input[index]); 258 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 259 | } 260 | } 261 | p_weight++; 262 | } 263 | } 264 | } 265 | grad_in[index] = value; 266 | } 267 | } 268 | if (GW==2) 269 | { 270 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 271 | { 272 | const int n = index / IW / IH / in_channels; 273 | const int c = index / IW / IH % in_channels; 274 | const int h = index / IW % IH; 275 | const int w = index % IW; 276 | 277 | scalar_t value = 0; 278 | if (c < in_channels/2) 279 | { 280 | for (int mm = 0; mm < out_channels/2; mm++) 281 | { 282 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW; 283 | scalar_t *p_weight = weight + (mm * in_channels/2 + c) * KH * KW; 284 | for (int kh = 0; kh < KH; kh++) 285 | { 286 | for (int kw = 0; kw < KW; kw++) 287 | { 288 | int oh = h + PH - kh; 289 | int ow = w + PW - kw; 290 | 291 | if ((oh % SH == 0) && (ow % SW == 0)) 292 | { 293 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 294 | if (boundary_condition) 295 | { 296 | oh = oh / SH; 297 | ow = ow / SW; 298 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 299 | scalar_t ht = HARDTANH(*p_weight - input[index]); 300 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 301 | } 302 | } 303 | p_weight++; 304 | } 305 | } 306 | } 307 | } 308 | else 309 | { 310 | for (int mm = out_channels/2; mm < out_channels; mm++) 311 | { 312 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW; 313 | scalar_t *p_weight = weight + (mm * in_channels/2 + c - out_channels/2) * KH * KW; 314 | for (int kh = 0; kh < KH; kh++) 315 | { 316 | for (int kw = 0; kw < KW; kw++) 317 | { 318 | int oh = h + PH - kh; 319 | int ow = w + PW - kw; 320 | 321 | if ((oh % SH == 0) && (ow % SW == 0)) 322 | { 323 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 324 | if (boundary_condition) 325 | { 326 | oh = oh / SH; 327 | ow = ow / SW; 328 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 329 | scalar_t ht = HARDTANH(*p_weight - input[index]); 330 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 331 | } 332 | } 333 | p_weight++; 334 | } 335 | } 336 | } 337 | } 338 | grad_in[index] = value; 339 | } 340 | } 341 | if (GW==in_channels) 342 | { 343 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 344 | { 345 | const int n = index / IW / IH / in_channels; 346 | const int c = index / IW / IH % in_channels; 347 | const int h = index / IW % IH; 348 | const int w = index % IW; 349 | 350 | scalar_t value = 0; 351 | 352 | const int grad_out_offset0 = (n * out_channels + c) * OH * OW; 353 | scalar_t *p_weight = weight + c * KH * KW; 354 | for (int kh = 0; kh < KH; kh++) 355 | { 356 | for (int kw = 0; kw < KW; kw++) 357 | { 358 | int oh = h + PH - kh; 359 | int ow = w + PW - kw; 360 | 361 | if ((oh % SH == 0) && (ow % SW == 0)) 362 | { 363 | const bool 
boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 364 | if (boundary_condition) 365 | { 366 | oh = oh / SH; 367 | ow = ow / SW; 368 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 369 | scalar_t ht = HARDTANH(*p_weight - input[index]); 370 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 371 | } 372 | } 373 | p_weight++; 374 | } 375 | } 376 | grad_in[index] = value; 377 | } 378 | } 379 | } 380 | 381 | template 382 | __global__ void adder_backward_grad_weight_kernel( 383 | scalar_t *grad_out, 384 | scalar_t *input, 385 | scalar_t *weight, 386 | scalar_t *grad_weight, 387 | const int batch_size, 388 | const int out_channels, 389 | const int in_channels, 390 | const int IW, const int IH, 391 | const int OW, const int OH, 392 | const int KW, const int KH, 393 | const int SW, const int SH, 394 | const int PW, const int PH, 395 | const int GW, const int GH) 396 | { 397 | int bidx = blockIdx.x; 398 | int kW = bidx % KW; 399 | int kH = bidx / KW % KH; 400 | int ch = bidx / KW / KH % in_channels; 401 | int mh = bidx / KW / KH / in_channels; 402 | 403 | if (GW == 2) { 404 | ch = bidx / KW / KH % (in_channels / 2); 405 | mh = bidx / KW / KH / (in_channels / 2); 406 | if (mh >= out_channels / 2) { 407 | ch = ch + in_channels / 2; 408 | } 409 | } 410 | if (GW == in_channels) { 411 | ch = bidx / KW / KH; 412 | mh = bidx / KW / KH; 413 | } 414 | 415 | scalar_t grad = 0; 416 | const int laneId = threadIdx.x % WARP_SIZE; 417 | const int batch = threadIdx.x / WARP_SIZE; 418 | const int nwarps = blockDim.x / WARP_SIZE; 419 | const int imageElements = OW * OH; 420 | 421 | for (int batchIdx = batch; batchIdx < batch_size; batchIdx += nwarps) { 422 | for (int idx = laneId; idx < imageElements; idx += WARP_SIZE) { 423 | int go_w_offset = idx % OW; 424 | int go_h_offset = idx / OW; 425 | 426 | int i_w_offset = go_w_offset * SW + kW - PW; 427 | int i_h_offset = go_h_offset * SH + kH - PH; 428 | 429 | int outputOffset = ((batchIdx * out_channels + mh) * OH) * OW + idx; 430 | if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < IW && i_h_offset < IH) { 431 | int inputOffset = ((batchIdx * in_channels + ch) * IH + i_h_offset) * IW + i_w_offset; 432 | grad += (input[inputOffset] - weight[bidx]) * grad_out[outputOffset]; 433 | } else { 434 | grad += -weight[bidx] * grad_out[outputOffset]; 435 | } 436 | } 437 | } 438 | 439 | __shared__ scalar_t shared[NUM_THREADS]; 440 | shared[threadIdx.x] = grad; 441 | __syncthreads(); 442 | 443 | for (int stride = blockDim.x / 2; stride > 0; stride /= 2) { 444 | if (threadIdx.x < stride) { 445 | shared[threadIdx.x] += shared[threadIdx.x + stride]; 446 | } 447 | __syncthreads(); 448 | } 449 | 450 | if (threadIdx.x == 0) { 451 | scalar_t tval = shared[0]; 452 | if (GW == 1) { 453 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * in_channels * mh); 454 | grad_weight[weightOffset] = tval; 455 | } 456 | if (GW == 2) { 457 | if (mh < out_channels / 2) { 458 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * (in_channels / 2) * mh); 459 | grad_weight[weightOffset] = tval; 460 | } else { 461 | int weightOffset = kW + (KW * kH) + (KW * KH * (ch - in_channels / 2)) + (KW * KH * (in_channels / 2) * mh); 462 | grad_weight[weightOffset] = tval; 463 | } 464 | } 465 | if (GW == in_channels) { 466 | int weightOffset = kW + (KW * kH) + (KW * KH * 0) + (KW * KH * 1 * mh); 467 | grad_weight[weightOffset] = tval; 468 | } 469 | } 470 | } 471 | 472 | 
//////////////////////////////////////////////////////////////////////// 473 | ////////////////////////////END OF KERNEL/////////////////////////////// 474 | //////////////////////////////////////////////////////////////////////// 475 | 476 | int adder_cuda_forward( 477 | const at::Tensor &input, 478 | const at::Tensor &weight, 479 | // const at::Tensor &bias, 480 | at::Tensor &output, 481 | int KW, int KH, 482 | int SW, int SH, 483 | int PW, int PH, 484 | int GW, int GH 485 | ) 486 | { 487 | const int batch_size = output.size(0); 488 | const int in_channels = input.size(1); 489 | const int out_channels = output.size(1); 490 | const int IW = input.size(3); 491 | const int IH = input.size(2); 492 | const int OW = output.size(3); 493 | const int OH = output.size(2); 494 | const int num_elem = batch_size * out_channels * OH * OW; 495 | const int num_blocks = get_blocks(num_elem); 496 | 497 | AT_DISPATCH_FLOATING_TYPES(output.scalar_type(), "adder_cuda_forward", ([&] { 498 | adder_forward_kernel<<>>( 499 | input.data_ptr(), 500 | weight.data_ptr(), 501 | output.data_ptr(), 502 | num_elem, 503 | out_channels, 504 | in_channels, 505 | IW, IH, 506 | OW, OH, 507 | KW, KH, 508 | SW, SH, 509 | PW, PH, 510 | GW, GH 511 | ); 512 | })); 513 | // AT_CUDA_CHECK(cudaGetLastError()); 514 | C10_CUDA_CHECK(cudaGetLastError()); 515 | return 1; 516 | } 517 | 518 | /* 519 | scalar_t *grad_out, 520 | scalar_t *weight, 521 | scalar_t *grad_in, 522 | const int num_elem, 523 | const int out_channels, 524 | const int in_channels, 525 | const int IW, const int IH, 526 | const int OW, const int OH, 527 | const int KW, const int KH, 528 | const int SW, const int SH, 529 | const int PW, const int PH, 530 | const int GW, const int GH 531 | */ 532 | 533 | int adder_cuda_backward_grad_in( 534 | at::Tensor &grad_out, 535 | at::Tensor &input, 536 | at::Tensor &weight, 537 | at::Tensor &grad_in, 538 | int KW, int KH, 539 | int SW, int SH, 540 | int PW, int PH, 541 | int GW, int GH 542 | ) 543 | { 544 | const int batch_size = grad_in.size(0); 545 | const int in_channels = grad_in.size(1); 546 | const int out_channels = grad_out.size(1); 547 | const int IW = grad_in.size(3); 548 | const int IH = grad_in.size(2); 549 | const int OW = grad_out.size(3); 550 | const int OH = grad_out.size(2); 551 | const int num_elem = batch_size * in_channels * IH * IW; 552 | const int num_blocks = get_blocks(num_elem); 553 | 554 | AT_DISPATCH_FLOATING_TYPES(grad_in.type(), "adder_cuda_backward_grad_in", ([&] { 555 | adder_backward_grad_in_kernel<<>>( 556 | grad_out.data_ptr(), 557 | input.data_ptr(), 558 | weight.data_ptr(), 559 | grad_in.data_ptr(), 560 | num_elem, 561 | out_channels, 562 | in_channels, 563 | IW, IH, 564 | OW, OH, 565 | KW, KH, 566 | SW, SH, 567 | PW, PH, 568 | GW, GH 569 | ); 570 | })); 571 | // AT_CUDA_CHECK(cudaGetLastError()); 572 | C10_CUDA_CHECK(cudaGetLastError()); 573 | return 1; 574 | } 575 | 576 | int adder_cuda_backward_grad_weight( 577 | at::Tensor &grad_out, 578 | at::Tensor &input, 579 | at::Tensor &weight, 580 | at::Tensor &grad_weight, 581 | int KW, int KH, 582 | int SW, int SH, 583 | int PW, int PH, 584 | int GW, int GH 585 | ) 586 | { 587 | const int batch_size = input.size(0); 588 | const int in_channels = input.size(1); 589 | const int out_channels = grad_out.size(1); 590 | const int IW = input.size(3); 591 | const int IH = input.size(2); 592 | const int OW = grad_out.size(3); 593 | const int OH = grad_out.size(2); 594 | 595 | int blocks = out_channels * in_channels * KH * KW; 596 | 597 | if (GW==2) 598 | { 
599 | blocks = out_channels * (in_channels/2) * KH * KW; 600 | } 601 | if (GW==in_channels) 602 | { 603 | blocks = out_channels * 1 * KH * KW; 604 | } 605 | 606 | // Make sure we have enough threads to perform the reduction, and use this number 607 | // to create the shared memory size for the reduction 608 | dim3 grid(blocks); 609 | dim3 block(getGradParamsNumThreads(batch_size)); 610 | // int smem = block.x * sizeof(accreal); 611 | 612 | AT_DISPATCH_FLOATING_TYPES(grad_weight.type(), "adder_cuda_backward_grad_weight", ([&] { 613 | adder_backward_grad_weight_kernel<<>>( 614 | grad_out.data_ptr(), 615 | input.data_ptr(), 616 | weight.data_ptr(), 617 | grad_weight.data_ptr(), 618 | batch_size, 619 | out_channels, 620 | in_channels, 621 | IW, IH, 622 | OW, OH, 623 | KW, KH, 624 | SW, SH, 625 | PW, PH, 626 | GW, GH); 627 | })); 628 | // AT_CUDA_CHECK(cudaGetLastError()); 629 | C10_CUDA_CHECK(cudaGetLastError()); 630 | return 1; 631 | } 632 | 633 | /* 634 | scalar_t *grad_out, 635 | scalar_t *input, 636 | scalar_t *grad_weight, 637 | const int batch_size, 638 | const int out_channels, 639 | const int in_channels, 640 | const int IW, const int IH, 641 | const int OW, const int OH, 642 | const int KW, const int KH, 643 | const int SW, const int SH, 644 | const int PW, const int PH, 645 | const int GW, const int GH 646 | */ -------------------------------------------------------------------------------- /torchshiftadd/layers/extension/adder_cuda_kernel_torch_1.4.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | // #include 4 | // #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #define MAX_BLOCKS 256 12 | #define NUM_THREADS 256 13 | #define MAX(a, b) ((a) > (b)) ? (a) : (b) 14 | #define MIN(a, b) ((a) < (b)) ? (a) : (b) 15 | #define HARDTANH(x) ((x) < (-1.0)) ? (-1.0) : (((x) <= (1.0)) ? (x) : (1.0)) 16 | const int WARP_SIZE = 32; 17 | // Crude benchmarks suggest 256 is better than 512 and 1024 18 | // TODO: Autotune/use better heuristics, improve speed more. 
19 | const int MAX_BLOCK_SIZE = 256; 20 | 21 | template 22 | struct SharedMem { 23 | __device__ T *getPointer() 24 | { 25 | extern __device__ void error(void); 26 | error(); 27 | return NULL; 28 | } 29 | }; 30 | 31 | 32 | static int getGradParamsNumThreads(int batchSize) 33 | { 34 | //warp per item in a batch, up to a maximum 35 | return std::min(batchSize * WARP_SIZE, MAX_BLOCK_SIZE); 36 | } 37 | 38 | int get_blocks(int n) 39 | { 40 | // return MAX(1, MIN(MAX_BLOCKS, (n - NUM_THREADS + 1) / NUM_THREADS)); 41 | return MIN(MAX_BLOCKS, (n - NUM_THREADS + 1) / NUM_THREADS) + 1; 42 | } 43 | 44 | template 45 | __global__ void adder_forward_kernel( 46 | const scalar_t const *input, 47 | const scalar_t const *weight, 48 | // const scalar_t const *bias, 49 | scalar_t *output, 50 | const int num_elem, 51 | const int out_channels, 52 | const int in_channels, 53 | const int IW, const int IH, 54 | const int OW, const int OH, 55 | const int KW, const int KH, 56 | const int SW, const int SH, 57 | const int PW, const int PH, 58 | const int GW, const int GH) 59 | { 60 | // #TODO: 61 | if (GW==1) 62 | { 63 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 64 | { 65 | const int n = index / OW / OH / out_channels; 66 | const int m = index / OW / OH % out_channels; 67 | const int h = index / OW % OH; 68 | const int w = index % OW; 69 | 70 | const scalar_t *p_weight = weight + m * in_channels * KH * KW; //the start position of the kernel(corresponding to the output) 71 | // scalar_t value = bias[m]; 72 | scalar_t value = 0; 73 | // #TODO: 74 | // #pragma unroll 75 | for (int cc = 0; cc < in_channels; cc++) 76 | { 77 | // #pragma unroll 78 | const int image_offset0 = (n * in_channels + cc) * IH * IW; 79 | for (int kh = 0; kh < KH; kh++) 80 | { 81 | // #pragma unroll 82 | for (int kw = 0; kw < KW; kw++) 83 | { 84 | const int ih = h * SH - PH + kh; 85 | const int iw = w * SW - PW + kw; 86 | 87 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 88 | if (boundary_condition) 89 | { 90 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 91 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); 92 | } 93 | else // padded area 94 | { 95 | value -= abs(*p_weight); 96 | } 97 | p_weight++; 98 | } 99 | } 100 | } 101 | output[index] = value; 102 | } 103 | } 104 | if (GW==2) 105 | { 106 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 107 | // total size of output (batch_size * out_channels * W-out * H_out) 108 | { 109 | // TODO 110 | const int n = index / OW / OH / out_channels; // batch size of n 111 | const int m = index / OW / OH % out_channels; // relative output channel size 112 | const int h = index / OW % OH; // relative position of H 113 | const int w = index % OW; //relative position of W 114 | 115 | const scalar_t *p_weight = weight + m * in_channels/2 * KH * KW; //the start position of the kernel(corresponding to the output) 116 | 117 | // scalar_t value = bias[m]; 118 | scalar_t value = 0; 119 | // #TODO: 120 | // #pragma unroll 121 | if (m < out_channels/2) 122 | { 123 | for (int cc = 0; cc < in_channels/2; cc++) 124 | { 125 | // #pragma unroll 126 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute) 127 | for (int kh = 0; kh < KH; kh++) 128 | { 129 | // #pragma unroll 130 | for (int kw = 0; kw < KW; kw++) 131 | { 132 | const int ih = h * SH - PH + kh; // *stride-padding of H 133 | const int iw = w * SW - 
PW + kw; // *stride-padding of W 134 | 135 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 136 | if (boundary_condition) 137 | { 138 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 139 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation 140 | } 141 | else // padded area 142 | { 143 | value -= abs(*p_weight); 144 | } 145 | p_weight++; 146 | } 147 | } 148 | } 149 | } 150 | else 151 | { 152 | for (int cc = in_channels/2; cc < in_channels; cc++) 153 | { 154 | // #pragma unroll 155 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute) 156 | for (int kh = 0; kh < KH; kh++) 157 | { 158 | // #pragma unroll 159 | for (int kw = 0; kw < KW; kw++) 160 | { 161 | const int ih = h * SH - PH + kh; // *stride-padding of H 162 | const int iw = w * SW - PW + kw; // *stride-padding of W 163 | 164 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 165 | if (boundary_condition) 166 | { 167 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 168 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation 169 | } 170 | else // padded area 171 | { 172 | value -= abs(*p_weight); 173 | } 174 | p_weight++; 175 | } 176 | } 177 | } 178 | } 179 | output[index] = value; 180 | } 181 | } 182 | if (GW==in_channels) //Dpws 183 | { 184 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 185 | { 186 | const int n = index / OW / OH / out_channels; 187 | const int m = index / OW / OH % out_channels; 188 | const int h = index / OW % OH; 189 | const int w = index % OW; 190 | 191 | const scalar_t *p_weight = weight + m * 1 * KH * KW; 192 | // scalar_t value = bias[m]; 193 | scalar_t value = 0; 194 | // #TODO: 195 | // #pragma unroll 196 | // #pragma unroll 197 | const int image_offset0 = (n * in_channels + m) * IH * IW; 198 | for (int kh = 0; kh < KH; kh++) 199 | { 200 | // #pragma unroll 201 | for (int kw = 0; kw < KW; kw++) 202 | { 203 | const int ih = h * SH - PH + kh; 204 | const int iw = w * SW - PW + kw; 205 | 206 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW); 207 | if (boundary_condition) 208 | { 209 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight); 210 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); 211 | } 212 | else // padded area 213 | { 214 | value -= abs(*p_weight); 215 | } 216 | p_weight++; 217 | } 218 | } 219 | output[index] = value; 220 | } 221 | } 222 | 223 | } 224 | 225 | template 226 | __global__ void adder_backward_grad_in_kernel( 227 | scalar_t *grad_out, 228 | scalar_t *input, 229 | scalar_t *weight, 230 | scalar_t *grad_in, 231 | const int num_elem, 232 | const int out_channels, 233 | const int in_channels, 234 | const int IW, const int IH, 235 | const int OW, const int OH, 236 | const int KW, const int KH, 237 | const int SW, const int SH, 238 | const int PW, const int PH, 239 | const int GW, const int GH) 240 | { if (GW==1) 241 | { 242 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 243 | { 244 | const int n = index / IW / IH / in_channels; 245 | const int c = index / IW / IH % in_channels; 246 | const int h = index / IW % IH; 247 | const int w = index % IW; 248 | 249 | scalar_t value = 0; 250 | for (int mm = 0; mm < out_channels; mm++) 251 | { 252 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW; 253 | scalar_t *p_weight = weight + (mm 
* in_channels + c) * KH * KW; 254 | for (int kh = 0; kh < KH; kh++) 255 | { 256 | for (int kw = 0; kw < KW; kw++) 257 | { 258 | int oh = h + PH - kh; 259 | int ow = w + PW - kw; 260 | 261 | if ((oh % SH == 0) && (ow % SW == 0)) 262 | { 263 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 264 | if (boundary_condition) 265 | { 266 | oh = oh / SH; 267 | ow = ow / SW; 268 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 269 | scalar_t ht = HARDTANH(*p_weight - input[index]); 270 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 271 | } 272 | } 273 | p_weight++; 274 | } 275 | } 276 | } 277 | grad_in[index] = value; 278 | } 279 | } 280 | if (GW==2) 281 | { 282 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 283 | { 284 | const int n = index / IW / IH / in_channels; 285 | const int c = index / IW / IH % in_channels; 286 | const int h = index / IW % IH; 287 | const int w = index % IW; 288 | 289 | scalar_t value = 0; 290 | if (c < in_channels/2) 291 | { 292 | for (int mm = 0; mm < out_channels/2; mm++) 293 | { 294 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW; 295 | scalar_t *p_weight = weight + (mm * in_channels/2 + c) * KH * KW; 296 | for (int kh = 0; kh < KH; kh++) 297 | { 298 | for (int kw = 0; kw < KW; kw++) 299 | { 300 | int oh = h + PH - kh; 301 | int ow = w + PW - kw; 302 | 303 | if ((oh % SH == 0) && (ow % SW == 0)) 304 | { 305 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 306 | if (boundary_condition) 307 | { 308 | oh = oh / SH; 309 | ow = ow / SW; 310 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 311 | scalar_t ht = HARDTANH(*p_weight - input[index]); 312 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 313 | } 314 | } 315 | p_weight++; 316 | } 317 | } 318 | } 319 | } 320 | else 321 | { 322 | for (int mm = out_channels/2; mm < out_channels; mm++) 323 | { 324 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW; 325 | scalar_t *p_weight = weight + (mm * in_channels/2 + c - out_channels/2) * KH * KW; 326 | for (int kh = 0; kh < KH; kh++) 327 | { 328 | for (int kw = 0; kw < KW; kw++) 329 | { 330 | int oh = h + PH - kh; 331 | int ow = w + PW - kw; 332 | 333 | if ((oh % SH == 0) && (ow % SW == 0)) 334 | { 335 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 336 | if (boundary_condition) 337 | { 338 | oh = oh / SH; 339 | ow = ow / SW; 340 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 341 | scalar_t ht = HARDTANH(*p_weight - input[index]); 342 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 343 | } 344 | } 345 | p_weight++; 346 | } 347 | } 348 | } 349 | } 350 | grad_in[index] = value; 351 | } 352 | } 353 | if (GW==in_channels) 354 | { 355 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x) 356 | { 357 | const int n = index / IW / IH / in_channels; 358 | const int c = index / IW / IH % in_channels; 359 | const int h = index / IW % IH; 360 | const int w = index % IW; 361 | 362 | scalar_t value = 0; 363 | 364 | const int grad_out_offset0 = (n * out_channels + c) * OH * OW; 365 | scalar_t *p_weight = weight + c * KH * KW; 366 | for (int kh = 0; kh < KH; kh++) 367 | { 368 | for (int kw = 0; kw < KW; kw++) 369 | { 370 | int oh = h + PH - kh; 371 | int ow = w + PW - kw; 372 | 373 | if ((oh % SH == 0) && (ow % SW == 0)) 374 | { 375 | const bool 
boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW); 376 | if (boundary_condition) 377 | { 378 | oh = oh / SH; 379 | ow = ow / SW; 380 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight); 381 | scalar_t ht = HARDTANH(*p_weight - input[index]); 382 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht; 383 | } 384 | } 385 | p_weight++; 386 | } 387 | } 388 | grad_in[index] = value; 389 | } 390 | } 391 | } 392 | 393 | template 394 | __global__ void adder_backward_grad_weight_kernel( 395 | scalar_t *grad_out, 396 | scalar_t *input, 397 | scalar_t *weight, 398 | scalar_t *grad_weight, 399 | const int batch_size, 400 | const int out_channels, 401 | const int in_channels, 402 | const int IW, const int IH, 403 | const int OW, const int OH, 404 | const int KW, const int KH, 405 | const int SW, const int SH, 406 | const int PW, const int PH, 407 | const int GW, const int GH) 408 | { 409 | SharedMem smem; 410 | int bidx = blockIdx.x; 411 | int kW = bidx % KW; 412 | int kH = bidx / KW % KH; 413 | int ch = bidx / KW / KH % in_channels; 414 | int mh = bidx / KW / KH / in_channels; 415 | 416 | if (GW==2) 417 | { 418 | ch = bidx / KW / KH % (in_channels/2); 419 | mh = bidx / KW / KH / (in_channels/2); 420 | if (mh >= out_channels/2) 421 | { 422 | ch = ch + in_channels/2; 423 | } 424 | } 425 | if (GW==in_channels) 426 | { 427 | ch = bidx / KW / KH; 428 | mh = bidx / KW / KH; 429 | } 430 | 431 | scalar_t grad = 0; 432 | const int laneId = threadIdx.x % WARP_SIZE; 433 | const int batch = threadIdx.x / WARP_SIZE; 434 | const int nwarps = blockDim.x / WARP_SIZE; 435 | const int imageElements = OW * OH; 436 | for (int batchIdx = batch; batchIdx < batch_size; batchIdx += nwarps) 437 | { 438 | // Warp-stride loop over elements in a batch item 439 | for (int idx = laneId; idx < imageElements; idx += WARP_SIZE) 440 | { 441 | // Need to calculate the following: batch position, and offset into the gradOutput 442 | // in height, and width. 
We can intuit the corresponding position in the input from 443 | // the other parameters we have 444 | int go_w_offset = idx % OW; 445 | int go_h_offset = (idx / OW); 446 | 447 | int i_w_offset = go_w_offset * SW + kW - PW; 448 | int i_h_offset = go_h_offset * SH + kH - PH; 449 | 450 | int outputOffset = ((batchIdx * out_channels + mh) * OH) * OW + idx; 451 | if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < IW && i_h_offset < IH) 452 | { 453 | int inputOffset = ((batchIdx * in_channels + ch) * IH + i_h_offset) * IW + i_w_offset; 454 | // int outputOffset = ((batchIdx * out_channels + mh) * OH) * OW + idx; 455 | // grad += input[inputOffset] * grad_out[outputOffset]; 456 | grad += (input[inputOffset] - weight[bidx]) * grad_out[outputOffset]; 457 | } 458 | else // padded area 459 | { 460 | grad += - weight[bidx] * grad_out[outputOffset]; 461 | } 462 | } 463 | } 464 | __syncthreads(); 465 | scalar_t *buf = smem.getPointer(); 466 | scalar_t tval = reduceBlock>( 467 | buf, blockDim.x, grad, ReduceAdd(), 0); 468 | 469 | // After reduction, first thread in the block has the gradient, so its responsible 470 | // for writing it to gradWeight 471 | if (threadIdx.x == 0) 472 | { 473 | if (GW==1) 474 | { 475 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * in_channels * mh); 476 | grad_weight[weightOffset] = tval; 477 | } 478 | if (GW==2) 479 | { 480 | if (mh < out_channels/2) 481 | { 482 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * in_channels/2 * mh); 483 | grad_weight[weightOffset] = tval; 484 | } 485 | else 486 | { 487 | int weightOffset = kW + (KW * kH) + (KW * KH * (ch - in_channels/2)) + (KW * KH * in_channels/2 * mh); 488 | grad_weight[weightOffset] = tval; 489 | } 490 | 491 | } 492 | if (GW==in_channels) 493 | { 494 | int weightOffset = kW + (KW * kH) + (KW * KH * 0) + (KW * KH * 1 * mh); 495 | grad_weight[weightOffset] = tval; 496 | } 497 | } 498 | } 499 | 500 | //////////////////////////////////////////////////////////////////////// 501 | ////////////////////////////END OF KERNEL/////////////////////////////// 502 | //////////////////////////////////////////////////////////////////////// 503 | 504 | int adder_cuda_forward( 505 | const at::Tensor &input, 506 | const at::Tensor &weight, 507 | // const at::Tensor &bias, 508 | at::Tensor &output, 509 | int KW, int KH, 510 | int SW, int SH, 511 | int PW, int PH, 512 | int GW, int GH 513 | ) 514 | { 515 | const int batch_size = output.size(0); 516 | const int in_channels = input.size(1); 517 | const int out_channels = output.size(1); 518 | const int IW = input.size(3); 519 | const int IH = input.size(2); 520 | const int OW = output.size(3); 521 | const int OH = output.size(2); 522 | const int num_elem = batch_size * out_channels * OH * OW; 523 | const int num_blocks = get_blocks(num_elem); 524 | 525 | AT_DISPATCH_FLOATING_TYPES(output.type(), "adder_cuda_forward", ([&] { 526 | adder_forward_kernel<<>>( 527 | input.data(), 528 | weight.data(), 529 | // bias.data(), 530 | output.data(), 531 | num_elem, 532 | out_channels, 533 | in_channels, 534 | IW, IH, 535 | OW, OH, 536 | KW, KH, 537 | SW, SH, 538 | PW, PH, 539 | GW, GH 540 | ); 541 | })); 542 | AT_CUDA_CHECK(cudaGetLastError()); 543 | return 1; 544 | } 545 | 546 | /* 547 | scalar_t *grad_out, 548 | scalar_t *weight, 549 | scalar_t *grad_in, 550 | const int num_elem, 551 | const int out_channels, 552 | const int in_channels, 553 | const int IW, const int IH, 554 | const int OW, const int OH, 555 | const int KW, const int KH, 556 | const int SW, const 
int SH, 557 | const int PW, const int PH, 558 | const int GW, const int GH 559 | */ 560 | 561 | int adder_cuda_backward_grad_in( 562 | at::Tensor &grad_out, 563 | at::Tensor &input, 564 | at::Tensor &weight, 565 | at::Tensor &grad_in, 566 | int KW, int KH, 567 | int SW, int SH, 568 | int PW, int PH, 569 | int GW, int GH 570 | ) 571 | { 572 | const int batch_size = grad_in.size(0); 573 | const int in_channels = grad_in.size(1); 574 | const int out_channels = grad_out.size(1); 575 | const int IW = grad_in.size(3); 576 | const int IH = grad_in.size(2); 577 | const int OW = grad_out.size(3); 578 | const int OH = grad_out.size(2); 579 | const int num_elem = batch_size * in_channels * IH * IW; 580 | const int num_blocks = get_blocks(num_elem); 581 | 582 | AT_DISPATCH_FLOATING_TYPES(grad_in.type(), "adder_cuda_backward_grad_in", ([&] { 583 | adder_backward_grad_in_kernel<<>>( 584 | grad_out.data(), 585 | input.data(), 586 | weight.data(), 587 | grad_in.data(), 588 | num_elem, 589 | out_channels, 590 | in_channels, 591 | IW, IH, 592 | OW, OH, 593 | KW, KH, 594 | SW, SH, 595 | PW, PH, 596 | GW, GH 597 | ); 598 | })); 599 | AT_CUDA_CHECK(cudaGetLastError()); 600 | return 1; 601 | } 602 | 603 | int adder_cuda_backward_grad_weight( 604 | at::Tensor &grad_out, 605 | at::Tensor &input, 606 | at::Tensor &weight, 607 | at::Tensor &grad_weight, 608 | int KW, int KH, 609 | int SW, int SH, 610 | int PW, int PH, 611 | int GW, int GH 612 | ) 613 | { 614 | const int batch_size = input.size(0); 615 | const int in_channels = input.size(1); 616 | const int out_channels = grad_out.size(1); 617 | const int IW = input.size(3); 618 | const int IH = input.size(2); 619 | const int OW = grad_out.size(3); 620 | const int OH = grad_out.size(2); 621 | 622 | int blocks = out_channels * in_channels * KH * KW; 623 | 624 | if (GW==2) 625 | { 626 | blocks = out_channels * (in_channels/2) * KH * KW; 627 | } 628 | if (GW==in_channels) 629 | { 630 | blocks = out_channels * 1 * KH * KW; 631 | } 632 | 633 | // Make sure we have enough threads to perform the reduction, and use this number 634 | // to create the shared memory size for the reduction 635 | dim3 grid(blocks); 636 | dim3 block(getGradParamsNumThreads(batch_size)); 637 | // int smem = block.x * sizeof(accreal); 638 | 639 | AT_DISPATCH_FLOATING_TYPES(grad_weight.type(), "adder_cuda_backward_grad_weight", ([&] { 640 | adder_backward_grad_weight_kernel<<>>( 641 | grad_out.data(), 642 | input.data(), 643 | weight.data(), 644 | grad_weight.data(), 645 | batch_size, 646 | out_channels, 647 | in_channels, 648 | IW, IH, 649 | OW, OH, 650 | KW, KH, 651 | SW, SH, 652 | PW, PH, 653 | GW, GH); 654 | })); 655 | AT_CUDA_CHECK(cudaGetLastError()); 656 | return 1; 657 | } 658 | 659 | /* 660 | scalar_t *grad_out, 661 | scalar_t *input, 662 | scalar_t *grad_weight, 663 | const int batch_size, 664 | const int out_channels, 665 | const int in_channels, 666 | const int IW, const int IH, 667 | const int OW, const int OH, 668 | const int KW, const int KH, 669 | const int SW, const int SH, 670 | const int PW, const int PH, 671 | const int GW, const int GH 672 | */ -------------------------------------------------------------------------------- /torchshiftadd/layers/shift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import torch.nn.functional as F 6 | from torch.autograd import Function 7 | from torch.nn.modules.utils import _pair 8 | import torchshiftadd.utils.ste as 
ste 9 | from torchshiftadd.utils.quantize import quantize_grad 10 | 11 | log2 = math.log(2) 12 | 13 | ###### FC 14 | 15 | class LinearShiftFunction(Function): 16 | 17 | @staticmethod 18 | def forward(ctx, input, shift, sign, bias=None, conc_weight=None, use_kernel=False, use_cuda=True, rounding='deterministic', shift_range=(-14, 0)): 19 | fraction_bits = 16 20 | integer_bit = 16 21 | 22 | sign = sign.clamp(-1,1) 23 | shift = shift.clamp(*shift_range) 24 | input.data = ste.round_to_fixed(input.data, fraction_bits, integer_bit) 25 | if bias is not None: 26 | bias.data = ste.round_to_fixed(bias.data, fraction_bits, integer_bit) 27 | 28 | v = 2**shift.round() * sign.round().sign() 29 | out = input.mm(v.t()) 30 | if bias is not None: 31 | out += bias.unsqueeze(0).expand_as(out) 32 | 33 | ctx.save_for_backward(input, shift, sign, bias, v) 34 | 35 | return out 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | 40 | input, shift, sign, bias, v = ctx.saved_tensors 41 | grad_input = grad_shift = grad_sign = grad_bias = None 42 | 43 | if ctx.needs_input_grad[0]: 44 | grad_input = grad_output.mm(v) 45 | if ctx.needs_input_grad[2]: 46 | grad_sign = grad_output.t().mm(input) 47 | if ctx.needs_input_grad[1]: 48 | if grad_sign is None: 49 | grad_shift = grad_output.t().mm(input) * v * log2 50 | else: 51 | grad_shift = grad_sign * v * log2 52 | if bias is not None and ctx.needs_input_grad[3]: 53 | grad_bias = grad_output.sum(0).squeeze(0) 54 | 55 | return grad_input, grad_shift, grad_sign, grad_bias, None, None, None 56 | 57 | 58 | class LinearShift(nn.Module): 59 | def __init__(self, in_features, out_features, bias=True, check_grad=False, freeze_sign=False, use_kernel=False, use_cuda=True, rounding='deterministic', weight_bits=5, threshold=None): 60 | super(LinearShift, self).__init__() 61 | self.in_features = in_features 62 | self.out_features = out_features 63 | self.use_kernel = use_kernel 64 | self.check_grad = check_grad 65 | self.use_cuda = use_cuda 66 | self.conc_weight = None 67 | self.rounding = rounding 68 | self.shift_range = (-1 * (2**(weight_bits - 1) - 2), 0) # we use ternary weights to represent sign 69 | self.threshold = threshold 70 | print(self.shift_range) 71 | 72 | if check_grad: 73 | tensor_constructor = torch.DoubleTensor # double precision required to check grad 74 | else: 75 | tensor_constructor = torch.Tensor # In PyTorch torch.Tensor is alias torch.FloatTensor 76 | 77 | self.shift = nn.Parameter(tensor_constructor(out_features, in_features)) 78 | self.sign = nn.Parameter(tensor_constructor(out_features, in_features), requires_grad = (freeze_sign == False)) 79 | 80 | if bias: 81 | self.bias = nn.Parameter(tensor_constructor(out_features)) 82 | else: 83 | self.register_parameter('bias', None) 84 | 85 | self.reset_parameters() 86 | 87 | def reset_parameters(self): 88 | self.shift.data.uniform_(*self.shift_range) 89 | self.sign.data.uniform_(-1, 1) 90 | 91 | if self.bias is not None: 92 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.shift) 93 | bound = 1 / math.sqrt(fan_in) 94 | init.uniform_(self.bias, -bound, bound) 95 | 96 | def forward(self, input): 97 | self.shift.data = ste.clamp(self.shift.data, *self.shift_range) 98 | shift_rounded = ste.round(self.shift, rounding=self.rounding) 99 | if self.threshold == None: 100 | sign_rounded_signed = ste.sign(ste.round(self.sign, rounding=self.rounding)) 101 | else: 102 | sign_rounded_signed = ste.sign(round(self.sign, self.threshold)) 103 | weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed) 104 | 
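        # Added comment (illustration, not part of the original layer): at this point
        #     weight_ps[i, j] = sign(round(sign[i, j])) * 2 ** round(shift[i, j]),
        # with shift clamped to self.shift_range and the rounded sign taking the
        # ternary values -1, 0 or +1. Every multiplication in the F.linear fallback
        # below therefore reduces to a sign flip plus a power-of-two scaling; for
        # example, shift = -3 with sign = +1 gives a weight of 2 ** -3 = 0.125,
        # i.e. a right shift by 3 bits in fixed point.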
105 | if self.use_kernel: 106 | return LinearShiftFunction.apply(input, self.shift, self.sign, self.bias, self.conc_weight, self.use_kernel, self.use_cuda, self.rounding, self.shift_range) 107 | else: 108 | return torch.nn.functional.linear(input, weight_ps, self.bias) 109 | 110 | def extra_repr(self): 111 | return 'in_features={}, out_features={}, bias={}'.format( 112 | self.in_features, self.out_features, self.bias is not None 113 | ) 114 | 115 | ##### Conv 116 | 117 | class _ConvNdShift(nn.Module): 118 | 119 | __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', 'padding_mode'] 120 | 121 | def __init__(self, in_channels, out_channels, kernel_size, stride, 122 | padding, dilation, transposed, output_padding, 123 | groups, bias, padding_mode, 124 | check_grad=False, freeze_sign=False, 125 | rounding='deterministic', weight_bits=5): 126 | super(_ConvNdShift, self).__init__() 127 | if in_channels % groups != 0: 128 | raise ValueError('in_channels must be divisible by groups') 129 | if out_channels % groups != 0: 130 | raise ValueError('out_channels must be divisible by groups') 131 | self.in_channels = in_channels 132 | self.out_channels = out_channels 133 | self.kernel_size = kernel_size 134 | self.stride = stride 135 | self.padding = padding 136 | self.dilation = dilation 137 | self.transposed = transposed 138 | self.output_padding = output_padding 139 | self.groups = groups 140 | self.padding_mode = padding_mode 141 | self.rounding=rounding 142 | self.shift_range = (-1 * (2**(weight_bits - 1) - 2), 0) # we use ternary weights to represent sign 143 | # for ps 144 | # self.shift_range = (-1 * weight_bits, 0) 145 | 146 | if check_grad: 147 | tensor_constructor = torch.DoubleTensor # double precision required to check grad 148 | else: 149 | tensor_constructor = torch.Tensor # In PyTorch torch.Tensor is alias torch.FloatTensor 150 | 151 | if transposed: 152 | self.shift = nn.Parameter(tensor_constructor( 153 | in_channels, out_channels // groups, *kernel_size)) 154 | self.sign = nn.Parameter(tensor_constructor( 155 | in_channels, out_channels // groups, *kernel_size), 156 | requires_grad = (freeze_sign == False)) 157 | else: 158 | self.shift = nn.Parameter(tensor_constructor( 159 | out_channels, in_channels // groups, *kernel_size)) 160 | self.sign = nn.Parameter(tensor_constructor( 161 | out_channels, in_channels // groups, *kernel_size), 162 | requires_grad = (freeze_sign == False)) 163 | if bias: 164 | self.bias = nn.Parameter(tensor_constructor(out_channels)) 165 | else: 166 | self.register_parameter('bias', None) 167 | self.reset_parameters(weight_bits) 168 | 169 | def reset_parameters(self, weight_bits): 170 | self.shift.data.uniform_(*self.shift_range) # (-0.1, 0.1) 171 | self.sign.data.uniform_(-1, 1) # (-0.1, 0.1) 172 | 173 | if self.bias is not None: 174 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.shift) 175 | bound = 1 / math.sqrt(fan_in) 176 | init.uniform_(self.bias, -bound, bound) 177 | 178 | def extra_repr(self): 179 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 180 | ', stride={stride}') 181 | if self.padding != (0,) * len(self.padding): 182 | s += ', padding={padding}' 183 | if self.dilation != (1,) * len(self.dilation): 184 | s += ', dilation={dilation}' 185 | if self.output_padding != (0,) * len(self.output_padding): 186 | s += ', output_padding={output_padding}' 187 | if self.groups != 1: 188 | s += ', groups={groups}' 189 | if self.bias is None: 190 | s += ', bias=False' 191 | return s.format(**self.__dict__) 192 | 193 | 194 | 
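Before the 2-D specialization that follows, it helps to spell out the parameterization _ConvNdShift sets up: each weight is stored as a power-of-two exponent (shift) plus a ternary sign, and shift_range = (-(2**(weight_bits - 1) - 2), 0) bounds the exponent, with the sign stored separately. A standalone sketch (illustration only, reusing the expressions from this file) of that range and of how the effective weight is recovered:

# Sketch only: mirrors the shift_range formula and the weight reconstruction
# (sign * 2 ** shift) used by LinearShift above and Conv2dShift below.
import torch

weight_bits = 5
shift_range = (-1 * (2 ** (weight_bits - 1) - 2), 0)
print(shift_range)  # (-14, 0)

shift = torch.tensor([-14.0, -3.0, 0.0])
sign = torch.tensor([0.7, -0.8, 1.0])
weight = torch.sign(sign.round()) * 2.0 ** shift.round()
print(weight)  # approx [6.1e-05, -0.125, 1.0]
# A sign with |value| < 0.5 would round to 0 and zero out that weight (ternary sign).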
class Conv2dShift(_ConvNdShift): 195 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 196 | padding=0, dilation=1, groups=1, 197 | bias=True, padding_mode='zeros', 198 | check_grad=False, freeze_sign=False, use_kernel=False, use_cuda=True, rounding='deterministic', weight_bits=5, threshold=0.3, input_bits=16): 199 | kernel_size = _pair(kernel_size) 200 | stride = _pair(stride) 201 | padding = _pair(padding) 202 | dilation = _pair(dilation) 203 | self.use_kernel = use_kernel 204 | self.use_cuda = use_cuda 205 | self.conc_weight = None 206 | self.threshold = threshold 207 | self.input_bits = input_bits 208 | super(Conv2dShift, self).__init__( 209 | in_channels, out_channels, kernel_size, stride, padding, dilation, 210 | False, _pair(0), groups, bias, padding_mode, 211 | check_grad, freeze_sign, rounding, weight_bits) 212 | 213 | #@weak_script_method 214 | def forward(self, input): 215 | self.shift.data = ste.clamp(self.shift.data, *self.shift_range) 216 | shift_rounded = ste.round(self.shift, self.rounding) 217 | 218 | if self.threshold is None: 219 | sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding)) 220 | else: 221 | sign_rounded_signed = ste.sign(ste.myround(self.sign, self.threshold)) 222 | weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed) 223 | 224 | input_fixed_point = ste.round_fixed_point(input, quant_bits=self.input_bits) 225 | 226 | if self.bias is not None: 227 | bias_fixed_point = ste.round_fixed_point(self.bias, quant_bits=self.input_bits) 228 | else: 229 | bias_fixed_point = None 230 | 231 | if self.padding_mode == 'circular': 232 | expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2, 233 | (self.padding[0] + 1) // 2, self.padding[0] // 2) 234 | 235 | input_padded = F.pad(input_fixed_point, expanded_padding, mode='circular') 236 | padding = _pair(0) 237 | else: 238 | input_padded = input_fixed_point 239 | padding = self.padding 240 | 241 | 242 | output = torch.nn.functional.conv2d(input_padded, weight_ps, bias_fixed_point, 243 | self.stride, padding, self.dilation, self.groups) 244 | 245 | # quantize backpropogation 246 | if self.input_bits > 0: 247 | output = quantize_grad(output, num_bits=self.input_bits, flatten_dims=(1, -1)) 248 | 249 | return output -------------------------------------------------------------------------------- /torchshiftadd/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet20 import resnet20 2 | from .resnet20_adder import resnet20_adder 3 | from .resnet20_shift import convert_to_shift 4 | from .resnet20_shiftadd import resnet20_shiftadd -------------------------------------------------------------------------------- /torchshiftadd/models/resnet20.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['resnet20'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return nn.Conv2d( 10 | in_planes, 11 | out_planes, 12 | kernel_size=3, 13 | stride=stride, 14 | padding=1, 15 | bias=False, 16 | ) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | 21 | expansion=1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride=stride) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = 
nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | class ResNet(nn.Module): 52 | 53 | def __init__(self, block, layers, num_classes=10): 54 | super(ResNet, self).__init__() 55 | self.inplanes = 16 56 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(16) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.layer1 = self._make_layer(block, 16, layers[0]) 60 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 61 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 62 | self.avgpool = nn.AvgPool2d(8, stride=1) 63 | self.fc = nn.Conv2d(64 * block.expansion, num_classes, 1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(num_classes) 65 | 66 | 67 | # init 68 | for m in self.modules(): 69 | if isinstance(m, nn.BatchNorm2d): 70 | m.weight.data.fill_(1) 71 | m.bias.data.zero_() 72 | 73 | def _make_layer(self, block, planes, blocks, stride=1): 74 | downsample = None 75 | if stride != 1 or self.inplanes != planes * block.expansion: 76 | downsample = nn.Sequential( 77 | nn.Conv2d( 78 | self.inplanes, 79 | planes * block.expansion, 80 | kernel_size=1, 81 | stride=stride, 82 | bias=False, 83 | ), 84 | nn.BatchNorm2d(planes * block.expansion) 85 | ) 86 | 87 | layers = [] 88 | layers.append( 89 | block( 90 | inplanes=self.inplanes, 91 | planes=planes, 92 | stride=stride, 93 | downsample=downsample, 94 | ) 95 | ) 96 | self.inplanes = planes * block.expansion 97 | for _ in range(1, blocks): 98 | layers.append( 99 | block( 100 | inplanes=self.inplanes, 101 | planes=planes, 102 | ) 103 | ) 104 | 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | x = self.conv1(x) 109 | x = self.bn1(x) 110 | x = self.relu(x) 111 | 112 | x = self.layer1(x) 113 | x = self.layer2(x) 114 | x = self.layer3(x) 115 | 116 | x = self.avgpool(x) 117 | x = self.fc(x) 118 | x = self.bn2(x) 119 | 120 | return x.view(x.size(0), -1) 121 | 122 | def resnet20(num_classes=10, **kwargs): 123 | return ResNet( 124 | BasicBlock, 125 | [3, 3, 3], 126 | num_classes=num_classes, 127 | ) -------------------------------------------------------------------------------- /torchshiftadd/models/resnet20_adder.py: -------------------------------------------------------------------------------- 1 | from torchshiftadd.layers import adder 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['resnet20_adder'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | return adder.Adder2D( 10 | in_planes, 11 | out_planes, 12 | kernel_size=3, 13 | stride=stride, 14 | padding=1, 15 | bias=False, 16 | ) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | 21 | expansion=1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride=stride) 26 | self.bn1 = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | 
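        # Added comment (not in the original model): conv1/conv2 in this block are
        # adder.Adder2D layers, so the feature map fed to bn1 above accumulates
        # -sum(|x - w|) over each window instead of the usual sum(x * w) -- this is
        # what the adder_forward_kernel in the adder CUDA sources (adder_cuda_kernel*.cu)
        # computes via "value -= abs(input[...] - (*p_weight));".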
out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | class ResNet(nn.Module): 52 | 53 | def __init__(self, block, layers, num_classes=10): 54 | super(ResNet, self).__init__() 55 | self.inplanes = 16 56 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(16) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.layer1 = self._make_layer(block, 16, layers[0]) 60 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 61 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 62 | self.avgpool = nn.AvgPool2d(8, stride=1) 63 | # use conv as fc layer (addernet) 64 | self.fc = nn.Conv2d(64 * block.expansion, num_classes, 1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(num_classes) 66 | 67 | 68 | # init (for adder) 69 | for m in self.modules(): 70 | if isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | def _make_layer(self, block, planes, blocks, stride=1): 75 | downsample = None 76 | if stride != 1 or self.inplanes != planes * block.expansion: 77 | downsample = nn.Sequential( 78 | adder.Adder2D( 79 | self.inplanes, 80 | planes * block.expansion, 81 | kernel_size=1, 82 | stride=stride, 83 | bias=False, 84 | ), 85 | nn.BatchNorm2d(planes * block.expansion) 86 | ) 87 | 88 | layers = [] 89 | layers.append( 90 | block( 91 | inplanes=self.inplanes, 92 | planes=planes, 93 | stride=stride, 94 | downsample=downsample, 95 | ) 96 | ) 97 | self.inplanes = planes * block.expansion 98 | for _ in range(1, blocks): 99 | layers.append( 100 | block( 101 | inplanes=self.inplanes, 102 | planes=planes, 103 | ) 104 | ) 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, x): 109 | x = self.conv1(x) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | 113 | x = self.layer1(x) 114 | x = self.layer2(x) 115 | x = self.layer3(x) 116 | 117 | x = self.avgpool(x) 118 | x = self.fc(x) 119 | x = self.bn2(x) 120 | 121 | return x.view(x.size(0), -1) 122 | 123 | def resnet20_adder(num_classes=10, **kwargs): 124 | return ResNet( 125 | BasicBlock, 126 | [3, 3, 3], 127 | num_classes=num_classes, 128 | ) -------------------------------------------------------------------------------- /torchshiftadd/models/resnet20_shift.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torchshiftadd.layers import shift 6 | 7 | def convert_to_shift(model): 8 | conversion_count = 0 9 | 10 | for name, module in reversed(model._modules.items()): 11 | 12 | if len(list(module.children())) > 0: 13 | model._modules[name], num_converted = convert_to_shift(model=module) 14 | conversion_count += num_converted 15 | 16 | if type(module) == nn.Conv2d: 17 | conv2d = module 18 | shift_conv2d = shift.Conv2dShift( 19 | module.in_channels, 20 | module.out_channels, 21 | module.kernel_size, 22 | module.stride, 23 | module.padding, 24 | module.dilation, 25 | module.groups, 26 | module.bias is not None, 27 | module.padding_mode 28 | ) 29 | shift_conv2d.shift.data, shift_conv2d.sign.data = get_shift_and_sign(conv2d.weight) 30 | shift_conv2d.bias = conv2d.bias 31 | model._modules[name] = shift_conv2d 32 | conversion_count += 1 33 | 34 | return model, conversion_count 35 | 36 | def get_shift_and_sign(x, 
rounding='deterministic'): 37 | sign = torch.sign(x) 38 | 39 | x_abs = torch.abs(x) 40 | shift = round(torch.log(x_abs) / np.log(2), rounding) 41 | 42 | return shift, sign 43 | 44 | def round(x, rounding='deterministic'): 45 | assert(rounding in ['deterministic', 'stochastic']) 46 | if rounding == 'stochastic': 47 | x_floor = x.floor() 48 | return x_floor + torch.bernoulli(x - x_floor) 49 | else: 50 | return x.round() -------------------------------------------------------------------------------- /torchshiftadd/models/resnet20_shiftadd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchshiftadd.layers import adder, shift 5 | 6 | __all__ = ['resnet20_shiftadd'] 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | shift_layer = shift.Conv2dShift( 11 | in_planes, 12 | out_planes, 13 | kernel_size=3, 14 | stride=stride, 15 | padding=1, 16 | bias=False 17 | ) 18 | add_layer = adder.Adder2D( 19 | out_planes, 20 | out_planes, 21 | kernel_size=3, 22 | stride=1, 23 | padding=1, 24 | bias=False, 25 | ) 26 | return nn.Sequential(shift_layer, add_layer) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | 31 | expansion=1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride=stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | class ResNet(nn.Module): 62 | 63 | def __init__(self, block, layers, num_classes=10): 64 | super(ResNet, self).__init__() 65 | self.inplanes = 16 66 | self.conv1 = shift.Conv2dShift(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(16) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.layer1 = self._make_layer(block, 16, layers[0]) 70 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 71 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 72 | self.avgpool = nn.AvgPool2d(8, stride=1) 73 | self.fc = shift.Conv2dShift(64 * block.expansion, num_classes, 1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(num_classes) 75 | 76 | 77 | # init 78 | for m in self.modules(): 79 | if isinstance(m, nn.BatchNorm2d): 80 | m.weight.data.fill_(1) 81 | m.bias.data.zero_() 82 | 83 | def _make_layer(self, block, planes, blocks, stride=1): 84 | downsample = None 85 | if stride != 1 or self.inplanes != planes * block.expansion: 86 | downsample = nn.Sequential( 87 | shift.Conv2dShift( 88 | self.inplanes, 89 | planes * block.expansion, 90 | kernel_size=1, 91 | stride=stride, 92 | bias=False, 93 | ), 94 | adder.Adder2D( 95 | planes * block.expansion, 96 | planes * block.expansion, 97 | kernel_size=1, 98 | stride=1, 99 | bias=False, 100 | ), 101 | nn.BatchNorm2d(planes * block.expansion) 102 | ) 103 | 104 | layers = [] 105 | layers.append( 106 | block( 107 | inplanes=self.inplanes, 108 | planes=planes, 109 | stride=stride, 110 | downsample=downsample, 111 | ) 112 | ) 113 | self.inplanes = planes 
* block.expansion 114 | for _ in range(1, blocks): 115 | layers.append( 116 | block( 117 | inplanes=self.inplanes, 118 | planes=planes, 119 | ) 120 | ) 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | x = self.conv1(x) 126 | x = self.bn1(x) 127 | x = self.relu(x) 128 | 129 | x = self.layer1(x) 130 | x = self.layer2(x) 131 | x = self.layer3(x) 132 | 133 | x = self.avgpool(x) 134 | x = self.fc(x) 135 | x = self.bn2(x) 136 | 137 | return x.view(x.size(0), -1) 138 | 139 | def resnet20_shiftadd(num_classes=10, **kwargs): 140 | return ResNet( 141 | BasicBlock, 142 | [3, 3, 3], 143 | num_classes=num_classes, 144 | ) -------------------------------------------------------------------------------- /torchshiftadd/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch import load_extension 2 | from . import comm 3 | from .ckpt_loading import load_add_state_dict, load_shiftadd_state_dict 4 | from .test_acc import test_acc 5 | 6 | __all__ = [ 7 | "load_extension", 8 | "comm", 9 | "load_add_state_dict", 10 | "load_shiftadd_state_dict", 11 | "test_acc" 12 | ] -------------------------------------------------------------------------------- /torchshiftadd/utils/ckpt_loading.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | def load_add_state_dict(state_dict): 4 | new_state_dict = OrderedDict() 5 | for k, v in state_dict.items(): 6 | if 'weight' in k and not 'bn' in k and not 'fc' in k: 7 | if k == 'conv1.weight' or 'downsample.1' in k: 8 | new_state_dict[k] = v 9 | continue 10 | k = k[:-6] + 'adder' 11 | # print(k) 12 | new_state_dict[k] = v 13 | return new_state_dict 14 | 15 | def load_shiftadd_state_dict(state_dict): 16 | from collections import OrderedDict 17 | new_state_dict = OrderedDict() 18 | for k, v in state_dict.items(): 19 | if 'weight' in k and not 'bn' in k and not 'fc' in k: 20 | if k == 'conv1.weight' or 'downsample.2' in k: 21 | new_state_dict[k] = v 22 | continue 23 | k = k[:-6] + 'adder' 24 | # print(k) 25 | new_state_dict[k] = v 26 | return new_state_dict -------------------------------------------------------------------------------- /torchshiftadd/utils/comm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import multiprocessing 3 | from collections import defaultdict 4 | 5 | import torch 6 | from torch import distributed as dist 7 | 8 | 9 | cpu_group = None 10 | gpu_group = None 11 | 12 | 13 | def get_rank(): 14 | """ 15 | Get the rank of this process in distributed processes. 16 | 17 | Return 0 for single process case. 18 | """ 19 | if dist.is_initialized(): 20 | return dist.get_rank() 21 | if "RANK" in os.environ: 22 | return int(os.environ["RANK"]) 23 | return 0 24 | 25 | 26 | def get_world_size(): 27 | """ 28 | Get the total number of distributed processes. 29 | 30 | Return 1 for single process case. 31 | """ 32 | if dist.is_initialized(): 33 | return dist.get_world_size() 34 | if "WORLD_SIZE" in os.environ: 35 | return int(os.environ["WORLD_SIZE"]) 36 | return 1 37 | 38 | 39 | def get_group(device): 40 | """ 41 | Get the process group corresponding to the given device. 42 | 43 | Parameters: 44 | device (torch.device): query device 45 | """ 46 | group = cpu_group if device.type == "cpu" else gpu_group 47 | if group is None: 48 | raise ValueError("%s group is not initialized. 
Use comm.init_process_group() to initialize it" 49 | % device.type.upper()) 50 | return group 51 | 52 | 53 | def init_process_group(backend, init_method=None, **kwargs): 54 | """ 55 | Initialize CPU and/or GPU process groups. 56 | 57 | Parameters: 58 | backend (str): Communication backend. Use ``nccl`` for GPUs and ``gloo`` for CPUs. 59 | init_method (str, optional): URL specifying how to initialize the process group 60 | """ 61 | global cpu_group 62 | global gpu_group 63 | 64 | dist.init_process_group(backend, init_method, **kwargs) 65 | gpu_group = dist.group.WORLD 66 | if backend == "nccl": 67 | cpu_group = dist.new_group(backend="gloo") 68 | else: 69 | cpu_group = gpu_group 70 | 71 | 72 | def get_cpu_count(): 73 | """ 74 | Get the number of CPUs on this node. 75 | """ 76 | return multiprocessing.cpu_count() 77 | 78 | 79 | def synchronize(): 80 | """ 81 | Synchronize among all distributed processes. 82 | """ 83 | if get_world_size() > 1: 84 | dist.barrier() 85 | 86 | 87 | def _recursive_read(obj): 88 | values = defaultdict(list) 89 | sizes = defaultdict(list) 90 | if isinstance(obj, torch.Tensor): 91 | values[obj.dtype] += [obj.flatten()] 92 | sizes[obj.dtype] += [torch.tensor([obj.numel()], device=obj.device)] 93 | elif isinstance(obj, dict): 94 | for v in obj.values(): 95 | child_values, child_sizes = _recursive_read(v) 96 | for k, v in child_values.items(): 97 | values[k] += v 98 | for k, v in child_sizes.items(): 99 | sizes[k] += v 100 | elif isinstance(obj, list) or isinstance(obj, tuple): 101 | for v in obj: 102 | child_values, child_sizes = _recursive_read(v) 103 | for k, v in child_values.items(): 104 | values[k] += v 105 | for k, v in child_sizes.items(): 106 | sizes[k] += v 107 | else: 108 | raise ValueError("Unknown type `%s`" % type(obj)) 109 | return values, sizes 110 | 111 | 112 | def _recursive_write(obj, values, sizes=None): 113 | if isinstance(obj, torch.Tensor): 114 | if sizes is None: 115 | size = torch.tensor([obj.numel()], device=obj.device) 116 | else: 117 | s = sizes[obj.dtype] 118 | size, s = s.split([1, len(s) - 1]) 119 | sizes[obj.dtype] = s 120 | v = values[obj.dtype] 121 | new_obj, v = v.split([size, v.shape[-1] - size], dim=-1) 122 | # compatible with reduce / stack / cat 123 | new_obj = new_obj.view(new_obj.shape[:-1] + (-1,) + obj.shape[1:]) 124 | values[obj.dtype] = v 125 | return new_obj, values 126 | elif isinstance(obj, dict): 127 | new_obj = {} 128 | for k, v in obj.items(): 129 | new_obj[k], values = _recursive_write(v, values, sizes) 130 | elif isinstance(obj, list) or isinstance(obj, tuple): 131 | new_obj = [] 132 | for v in obj: 133 | new_v, values = _recursive_write(v, values, sizes) 134 | new_obj.append(new_v) 135 | else: 136 | raise ValueError("Unknown type `%s`" % type(obj)) 137 | return new_obj, values 138 | 139 | 140 | def reduce(obj, op="sum", dst=None): 141 | """ 142 | Reduce any nested container of tensors. 143 | 144 | Parameters: 145 | obj (Object): any container object. Can be nested list, tuple or dict. 146 | op (str, optional): element-wise reduction operator. 147 | Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``. 148 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 
149 | 150 | Example:: 151 | 152 | >>> # assume 4 workers 153 | >>> rank = comm.get_rank() 154 | >>> x = torch.rand(5) 155 | >>> obj = {"polynomial": x ** rank} 156 | >>> obj = comm.reduce(obj) 157 | >>> assert torch.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1) 158 | """ 159 | values = _recursive_read(obj)[0] 160 | values = {k: torch.cat(v) for k, v in values.items()} 161 | 162 | is_mean = op == "mean" 163 | if is_mean: 164 | op = "sum" 165 | op = getattr(dist.ReduceOp, op.upper()) 166 | 167 | reduced = {} 168 | for k, v in values.items(): 169 | dtype = v.dtype 170 | # NCCL can't solve bool. Cast them to byte 171 | if dtype == torch.bool: 172 | v = v.byte() 173 | group = get_group(v.device) 174 | if dst is None: 175 | dist.all_reduce(v, op=op, group=group) 176 | else: 177 | dist.reduce(v, op=op, dst=dst, group=group) 178 | if is_mean: 179 | v = v / get_world_size() 180 | reduced[k] = v.type(dtype) 181 | 182 | return _recursive_write(obj, reduced)[0] 183 | 184 | 185 | def stack(obj, dst=None): 186 | """ 187 | Stack any nested container of tensors. The new dimension will be added at the 0-th axis. 188 | 189 | Parameters: 190 | obj (Object): any container object. Can be nested list, tuple or dict. 191 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 192 | 193 | Example:: 194 | 195 | >>> # assume 4 workers 196 | >>> rank = comm.get_rank() 197 | >>> x = torch.rand(5) 198 | >>> obj = {"exponent": x ** rank} 199 | >>> obj = comm.stack(obj) 200 | >>> truth = torch.stack([torch.ones_like(x), x, x ** 2, x ** 3] 201 | >>> assert torch.allclose(obj["exponent"], truth)) 202 | """ 203 | values = _recursive_read(obj)[0] 204 | values = {k: torch.cat(v) for k, v in values.items()} 205 | 206 | stacked = {} 207 | for k, v in values.items(): 208 | dtype = v.dtype 209 | # NCCL can't solve bool. Cast them to byte 210 | if dtype == torch.bool: 211 | dtype = torch.uint8 212 | s = torch.zeros(get_world_size(), *v.shape, dtype=dtype, device=v.device) 213 | s[get_rank()] = v 214 | group = get_group(s.device) 215 | if dst is None: 216 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) 217 | else: 218 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) 219 | stacked[k] = s.type(v.dtype) 220 | 221 | return _recursive_write(obj, stacked)[0] 222 | 223 | 224 | def cat(obj, dst=None): 225 | """ 226 | Concatenate any nested container of tensors along the 0-th axis. 227 | 228 | Parameters: 229 | obj (Object): any container object. Can be nested list, tuple or dict. 230 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers. 231 | 232 | Example:: 233 | 234 | >>> # assume 4 workers 235 | >>> rank = comm.get_rank() 236 | >>> rng = torch.arange(10) 237 | >>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]} 238 | >>> obj = comm.cat(obj) 239 | >>> assert torch.allclose(obj["range"], rng) 240 | """ 241 | values, sizes = _recursive_read(obj) 242 | sizes = {k: torch.cat(v) for k, v in sizes.items()} 243 | 244 | sizes = stack(sizes) 245 | cated = {} 246 | for k, value in values.items(): 247 | size = sizes[k].t().flatten() # sizes[k]: (num_worker, num_obj) 248 | dtype = value[0].dtype 249 | # NCCL can't solve bool. 
Cast them to byte 250 | if dtype == torch.bool: 251 | dtype = torch.uint8 252 | s = torch.zeros(size.sum(), dtype=dtype, device=value[0].device) 253 | obj_id = get_rank() 254 | world_size = get_world_size() 255 | offset = size[:obj_id].sum() 256 | for v in value: 257 | assert offset + v.numel() <= len(s) 258 | s[offset: offset + v.numel()] = v 259 | offset += size[obj_id: obj_id + world_size].sum() 260 | obj_id += world_size 261 | group = get_group(s.device) 262 | if dst is None: 263 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group) 264 | else: 265 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group) 266 | cated[k] = s.type(value[0].dtype) 267 | sizes = {k: v.sum(dim=0) for k, v in sizes.items()} 268 | 269 | return _recursive_write(obj, cated, sizes)[0] -------------------------------------------------------------------------------- /torchshiftadd/utils/decorator.py: -------------------------------------------------------------------------------- 1 | class cached_property(property): 2 | """ 3 | Cache the property once computed. 4 | """ 5 | 6 | def __init__(self, func): 7 | self.func = func 8 | self.__doc__ = func.__doc__ 9 | 10 | def __get__(self, obj, cls): 11 | if obj is None: 12 | return self 13 | result = self.func(obj) 14 | obj.__dict__[self.func.__name__] = result 15 | return result 16 | -------------------------------------------------------------------------------- /torchshiftadd/utils/quantize.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd.function import InplaceFunction, Function 7 | 8 | QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits']) 9 | 10 | _DEFAULT_FLATTEN = (1, -1) 11 | _DEFAULT_FLATTEN_GRAD = (0, -1) 12 | 13 | 14 | def _deflatten_as(x, x_full): 15 | shape = list(x.shape) + [1] * (x_full.dim() - x.dim()) 16 | return x.view(*shape) 17 | 18 | 19 | def calculate_qparams(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False, true_zero=False): 20 | with torch.no_grad(): 21 | x_flat = x.flatten(*flatten_dims) 22 | if x_flat.dim() == 1: 23 | min_values = _deflatten_as(x_flat.min(), x) 24 | max_values = _deflatten_as(x_flat.max(), x) 25 | else: 26 | min_values = _deflatten_as(x_flat.min(-1)[0], x) 27 | max_values = _deflatten_as(x_flat.max(-1)[0], x) 28 | 29 | if reduce_dim is not None: 30 | if reduce_type == 'mean': 31 | min_values = min_values.mean(reduce_dim, keepdim=keepdim) 32 | max_values = max_values.mean(reduce_dim, keepdim=keepdim) 33 | else: 34 | min_values = min_values.min(reduce_dim, keepdim=keepdim)[0] 35 | max_values = max_values.max(reduce_dim, keepdim=keepdim)[0] 36 | 37 | range_values = max_values - min_values 38 | return QParams(range=range_values, zero_point=min_values, 39 | num_bits=num_bits) 40 | 41 | 42 | class UniformQuantize(InplaceFunction): 43 | 44 | @staticmethod 45 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, 46 | reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False): 47 | 48 | ctx.inplace = inplace 49 | 50 | if ctx.inplace: 51 | ctx.mark_dirty(input) 52 | output = input 53 | else: 54 | output = input.clone() 55 | 56 | if qparams is None: 57 | assert num_bits is not None, "either provide qparams of num_bits to quantize" 58 | qparams = calculate_qparams( 59 | input, num_bits=num_bits, flatten_dims=flatten_dims, 
reduce_dim=reduce_dim) 60 | 61 | zero_point = qparams.zero_point 62 | num_bits = qparams.num_bits 63 | qmin = -(2.**(num_bits - 1)) if signed else 0. 64 | qmax = qmin + 2.**num_bits - 1. 65 | scale = qparams.range / (qmax - qmin) 66 | 67 | min_scale = torch.tensor(1e-8).expand_as(scale).cuda() 68 | scale = torch.max(scale, min_scale) 69 | 70 | with torch.no_grad(): 71 | output.add_(qmin * scale - zero_point).div_(scale) 72 | if stochastic: 73 | # print('use stochastic') 74 | noise = output.new(output.shape).uniform_(-0.5, 0.5) 75 | output.add_(noise) 76 | # quantize 77 | output.clamp_(qmin, qmax).round_() 78 | 79 | if dequantize: 80 | output.mul_(scale).add_( 81 | zero_point - qmin * scale) # dequantize 82 | return output 83 | 84 | @staticmethod 85 | def backward(ctx, grad_output): 86 | # straight-through estimator 87 | grad_input = grad_output 88 | return grad_input, None, None, None, None, None, None, None, None 89 | 90 | 91 | class UniformQuantizeGrad(InplaceFunction): 92 | 93 | @staticmethod 94 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, 95 | reduce_dim=0, dequantize=True, signed=False, stochastic=True): 96 | ctx.num_bits = num_bits 97 | ctx.qparams = qparams 98 | ctx.flatten_dims = flatten_dims 99 | ctx.stochastic = stochastic 100 | ctx.signed = signed 101 | ctx.dequantize = dequantize 102 | ctx.reduce_dim = reduce_dim 103 | ctx.inplace = False 104 | return input 105 | 106 | @staticmethod 107 | def backward(ctx, grad_output): 108 | qparams = ctx.qparams 109 | with torch.no_grad(): 110 | if qparams is None: 111 | assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize" 112 | qparams = calculate_qparams( 113 | grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, reduce_type='extreme') 114 | 115 | grad_input = quantize(grad_output, num_bits=None, 116 | qparams=qparams, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, 117 | dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False) 118 | return grad_input, None, None, None, None, None, None, None 119 | 120 | 121 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None): 122 | out1 = F.conv2d(input.detach(), weight, bias, 123 | stride, padding, dilation, groups) 124 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None, 125 | stride, padding, dilation, groups) 126 | out2 = quantize_grad(out2, num_bits=num_bits_grad, flatten_dims=(1, -1)) 127 | return out1 + out2 - out1.detach() 128 | 129 | 130 | def linear_biprec(input, weight, bias=None, num_bits_grad=None): 131 | out1 = F.linear(input.detach(), weight, bias) 132 | out2 = F.linear(input, weight.detach(), bias.detach() 133 | if bias is not None else None) 134 | out2 = quantize_grad(out2, num_bits=num_bits_grad) 135 | return out1 + out2 - out1.detach() 136 | 137 | 138 | def quantize(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False): 139 | if qparams: 140 | if qparams.num_bits: 141 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace) 142 | elif num_bits: 143 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace) 144 | 145 | return x 146 | 147 | 148 | def quantize_grad(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, 
dequantize=True, signed=False, stochastic=True): 149 | if qparams: 150 | if qparams.num_bits: 151 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic) 152 | elif num_bits: 153 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic) 154 | 155 | return x 156 | 157 | 158 | class QuantMeasure(nn.Module): 159 | """docstring for QuantMeasure.""" 160 | 161 | def __init__(self, shape_measure=(1,), flatten_dims=_DEFAULT_FLATTEN, 162 | inplace=False, dequantize=True, stochastic=False, momentum=0.9, measure=False): 163 | super(QuantMeasure, self).__init__() 164 | self.register_buffer('running_zero_point', torch.zeros(*shape_measure)) 165 | self.register_buffer('running_range', torch.zeros(*shape_measure)) 166 | self.measure = measure 167 | if self.measure: 168 | self.register_buffer('num_measured', torch.zeros(1)) 169 | self.flatten_dims = flatten_dims 170 | self.momentum = momentum 171 | self.dequantize = dequantize 172 | self.stochastic = stochastic 173 | self.inplace = inplace 174 | 175 | def forward(self, input, num_bits, qparams=None): 176 | 177 | if self.training or self.measure: 178 | if qparams is None: 179 | qparams = calculate_qparams( 180 | input, num_bits=num_bits, flatten_dims=self.flatten_dims, reduce_dim=0, reduce_type='extreme') 181 | with torch.no_grad(): 182 | if self.measure: 183 | momentum = self.num_measured / (self.num_measured + 1) 184 | self.num_measured += 1 185 | else: 186 | momentum = self.momentum 187 | self.running_zero_point.mul_(momentum).add_( 188 | qparams.zero_point * (1 - momentum)) 189 | self.running_range.mul_(momentum).add_( 190 | qparams.range * (1 - momentum)) 191 | else: 192 | qparams = QParams(range=self.running_range, 193 | zero_point=self.running_zero_point, num_bits=num_bits) 194 | if self.measure: 195 | return input 196 | else: 197 | q_input = quantize(input, qparams=qparams, dequantize=self.dequantize, 198 | stochastic=self.stochastic, inplace=self.inplace) 199 | return q_input 200 | 201 | 202 | class QConv2d(nn.Conv2d): 203 | """docstring for QConv2d.""" 204 | 205 | def __init__(self, in_channels, out_channels, kernel_size, 206 | stride=1, padding=0, dilation=1, groups=1, bias=True, momentum=0.1, quant_act_forward=0, quant_act_backward=0, quant_grad_act_error=0, quant_grad_act_gc=0, weight_bits=0, fix_prec=False): 207 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, 208 | stride, padding, dilation, groups, bias) 209 | 210 | self.quantize_input_fw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum) 211 | self.quantize_input_bw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum) 212 | self.quant_act_forward = quant_act_forward 213 | self.quant_act_backward = quant_act_backward 214 | self.quant_grad_act_error = quant_grad_act_error 215 | self.quant_grad_act_gc = quant_grad_act_gc 216 | self.weight_bits = weight_bits 217 | self.fix_prec = fix_prec 218 | self.stride = stride 219 | 220 | 221 | def forward(self, input, num_bits, num_grad_bits): 222 | if num_bits == 0: 223 | output = F.conv2d(input, self.weight, self.bias, self.stride,self.padding, self.dilation, self.groups) 224 | return output 225 | 226 | if self.bias is not None: 227 | qbias = quantize( 228 | self.bias, num_bits=self.num_bits_weight + self.num_bits, 229 | flatten_dims=(0, -1)) 230 | else: 231 | qbias = None 232 | 233 | if self.fix_prec: 234 | if self.quant_act_forward or 
self.quant_act_backward or self.quant_grad_act_error or self.quant_grad_act_gc or self.weight_bits: 235 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None) 236 | qweight = quantize(self.weight, qparams=weight_qparams) 237 | 238 | qinput_fw = self.quantize_input_fw(input, self.quant_act_forward) 239 | qinput_bw = self.quantize_input_bw(input, self.quant_act_backward) 240 | 241 | error_bits = self.quant_grad_act_error 242 | gc_bits = self.quant_grad_act_gc 243 | output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits) 244 | 245 | else: 246 | qinput = self.quantize_input_fw(input, num_bits) 247 | weight_qparams = calculate_qparams(self.weight, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None) 248 | qweight = quantize(self.weight, qparams=weight_qparams) 249 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups) 250 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1)) 251 | 252 | return output 253 | 254 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None) 255 | qweight = quantize(self.weight, qparams=weight_qparams) 256 | 257 | qinput = self.quantize_input_fw(input, num_bits) 258 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups) 259 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1)) 260 | 261 | # if self.quant_act_forward == -1: 262 | # qinput_fw = self.quantize_input_fw(input, num_bits) 263 | # else: 264 | # qinput_fw = self.quantize_input_fw(input, self.quant_act_forward) 265 | 266 | # if self.quant_act_backward == -1: 267 | # qinput_bw = self.quantize_input_bw(input, num_bits) 268 | # else: 269 | # qinput_bw = self.quantize_input_bw(input, self.quant_act_backward) 270 | 271 | # if self.quant_grad_act_error == -1: 272 | # error_bits = num_grad_bits 273 | # else: 274 | # error_bits = self.quant_grad_act_error 275 | 276 | # if self.quant_grad_act_gc == -1: 277 | # gc_bits = num_grad_bits 278 | # else: 279 | # gc_bits = self.quant_grad_act_gc 280 | 281 | # output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits) 282 | 283 | return output 284 | 285 | 286 | def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, error_bits=0, gc_bits=0): 287 | out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None, 288 | stride, padding, dilation, groups) 289 | out2 = F.conv2d(input_bw.detach(), weight, bias, 290 | stride, padding, dilation, groups) 291 | out1 = quantize_grad(out1, num_bits=error_bits) 292 | out2 = quantize_grad(out2, num_bits=gc_bits) 293 | return out1 + out2 - out2.detach() 294 | 295 | 296 | class QLinear(nn.Linear): 297 | """docstring for QConv2d.""" 298 | 299 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=8, num_bits_grad=8, biprecision=False): 300 | super(QLinear, self).__init__(in_features, out_features, bias) 301 | self.num_bits = num_bits 302 | self.num_bits_weight = num_bits_weight or num_bits 303 | self.num_bits_grad = num_bits_grad 304 | self.biprecision = biprecision 305 | self.quantize_input = QuantMeasure(shape_measure=(1, 1), flatten_dims=(1, -1), momentum=0.1) 306 | 307 | def forward(self, input, num_bits, 
num_bits_grad): 308 | # self.quantize_input = QuantMeasure(num_bits) 309 | 310 | qinput = self.quantize_input(input, num_bits) 311 | weight_qparams = calculate_qparams( 312 | self.weight, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None) 313 | qweight = quantize(self.weight, qparams=weight_qparams) 314 | if self.bias is not None: 315 | qbias = quantize( 316 | self.bias, num_bits=num_bits, 317 | flatten_dims=(0, -1)) 318 | else: 319 | qbias = None 320 | 321 | # if not self.biprecision or self.num_bits_grad is None: 322 | output = F.linear(qinput, qweight, qbias) 323 | # if self.num_bits_grad is not None: 324 | output = quantize_grad( 325 | output, num_bits=num_bits_grad) 326 | # else: 327 | # output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad) 328 | 329 | 330 | return output 331 | 332 | 333 | class RangeBN(nn.Module): 334 | # this is normalized RangeBN 335 | 336 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8): 337 | super(RangeBN, self).__init__() 338 | self.register_buffer('running_mean', torch.zeros(num_features)) 339 | self.register_buffer('running_var', torch.zeros(num_features)) 340 | 341 | self.momentum = momentum 342 | self.dim = dim 343 | if affine: 344 | self.bias = nn.Parameter(torch.Tensor(num_features)) 345 | self.weight = nn.Parameter(torch.Tensor(num_features)) 346 | self.num_bits = num_bits 347 | self.num_bits_grad = num_bits_grad 348 | self.quantize_input = QuantMeasure(inplace=True, shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1)) 349 | self.eps = eps 350 | self.num_chunks = num_chunks 351 | self.reset_params() 352 | 353 | def reset_params(self): 354 | if self.weight is not None: 355 | self.weight.data.uniform_() 356 | if self.bias is not None: 357 | self.bias.data.zero_() 358 | 359 | def forward(self, x, num_bits, num_grad_bits): 360 | x = self.quantize_input(x, num_bits) 361 | if x.dim() == 2: # 1d 362 | x = x.unsqueeze(-1,).unsqueeze(-1) 363 | 364 | if self.training: 365 | B, C, H, W = x.shape 366 | y = x.transpose(0, 1).contiguous() # C x B x H x W 367 | y = y.view(C, self.num_chunks, (B * H * W) // self.num_chunks) 368 | mean_max = y.max(-1)[0].mean(-1) # C 369 | mean_min = y.min(-1)[0].mean(-1) # C 370 | mean = y.view(C, -1).mean(-1) # C 371 | scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 372 | 0.5) / ((2 * math.log(y.size(-1))) ** 0.5) 373 | 374 | scale = (mean_max - mean_min) * scale_fix 375 | with torch.no_grad(): 376 | self.running_mean.mul_(self.momentum).add_( 377 | mean * (1 - self.momentum)) 378 | 379 | self.running_var.mul_(self.momentum).add_( 380 | scale * (1 - self.momentum)) 381 | else: 382 | mean = self.running_mean 383 | scale = self.running_var 384 | # scale = quantize(scale, num_bits=self.num_bits, min_value=float( 385 | # scale.min()), max_value=float(scale.max())) 386 | out = (x - mean.view(1, -1, 1, 1)) / \ 387 | (scale.view(1, -1, 1, 1) + self.eps) 388 | 389 | if self.weight is not None: 390 | qweight = self.weight 391 | # qweight = quantize(self.weight, num_bits=self.num_bits, 392 | # min_value=float(self.weight.min()), 393 | # max_value=float(self.weight.max())) 394 | out = out * qweight.view(1, -1, 1, 1) 395 | 396 | if self.bias is not None: 397 | qbias = self.bias 398 | # qbias = quantize(self.bias, num_bits=self.num_bits) 399 | out = out + qbias.view(1, -1, 1, 1) 400 | if num_grad_bits: 401 | out = quantize_grad( 402 | out, num_bits=num_grad_bits, flatten_dims=(1, -1)) 403 | 404 | if out.size(3) == 1 and out.size(2) == 1: 405 | out 
= out.squeeze(-1).squeeze(-1) 406 | return out 407 | 408 | 409 | class RangeBN1d(RangeBN): 410 | # this is normalized RangeBN 411 | 412 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8): 413 | super(RangeBN1d, self).__init__(num_features, dim, momentum, 414 | affine, num_chunks, eps, num_bits, num_bits_grad) 415 | self.quantize_input = QuantMeasure( 416 | inplace=True, shape_measure=(1, 1), flatten_dims=(1, -1))  # QuantMeasure takes no bit-width argument; num_bits is passed at forward time 417 | 418 | if __name__ == '__main__': 419 | x = torch.rand(2, 3) 420 | x_q = quantize(x, flatten_dims=(-1), num_bits=8, dequantize=True) 421 | print(x) 422 | print(x_q) -------------------------------------------------------------------------------- /torchshiftadd/utils/ste.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import numpy as np 4 | import math 5 | 6 | def dynamic_range_for_sign(sign, threshold): 7 | # print(sign, threshold) 8 | with torch.no_grad(): 9 | sign.data[sign.data < -threshold] = -1 10 | sign.data[sign.data > threshold] = 1 11 | sign.data[(-threshold <= sign.data) & (sign.data <= threshold)] = 0 12 | return sign 13 | 14 | class myRoundFunction(Function): 15 | @staticmethod 16 | def forward(ctx, input, threshold): 17 | return dynamic_range_for_sign(input, threshold) 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | return grad_output, None 22 | 23 | def myround(input, threshold): 24 | return myRoundFunction.apply(input, threshold) 25 | 26 | 27 | ############## 28 | 29 | def get_shift_and_sign(x, rounding='deterministic'): 30 | sign = torch.sign(x) 31 | 32 | x_abs = torch.abs(x) 33 | shift = round(torch.log(x_abs) / np.log(2), rounding) 34 | 35 | return shift, sign 36 | 37 | def _round_power_of_2_impl(x, rounding='deterministic'):  # private helper; the public round_power_of_2 below would otherwise shadow this name and recurse 38 | shift, sign = get_shift_and_sign(x, rounding) 39 | # print(shift) 40 | x_rounded = (2.0 ** shift) * sign 41 | return x_rounded 42 | 43 | class RoundPowerOf2(Function): 44 | @staticmethod 45 | def forward(ctx, input, stochastic=False): 46 | return _round_power_of_2_impl(input, 'stochastic' if stochastic else 'deterministic') 47 | 48 | @staticmethod 49 | def backward(ctx, grad_output): 50 | return grad_output, None 51 | 52 | def round_power_of_2(input, stochastic=False): 53 | return RoundPowerOf2.apply(input, stochastic) 54 | 55 | def round_to_fixed(input, fraction=16, integer=16): 56 | assert integer >= 1, integer 57 | if integer == 1: 58 | return torch.sign(input) - 1 59 | delta = math.pow(2.0, -(fraction)) 60 | bound = math.pow(2.0, integer-1) 61 | min_val = - bound 62 | max_val = bound - 1 63 | rounded = torch.floor(input / delta) * delta 64 | 65 | clipped_value = torch.clamp(rounded, min_val, max_val) 66 | return clipped_value 67 | 68 | class RoundFixedPoint(Function): 69 | @staticmethod 70 | def forward(ctx, input, quant_bits): 71 | return round_to_fixed(input, fraction=quant_bits) 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | return grad_output, None 76 | 77 | def round_fixed_point(input, quant_bits): 78 | return RoundFixedPoint.apply(input, quant_bits) 79 | 80 | def new_round(x, rounding='deterministic'): 81 | assert(rounding in ['deterministic', 'stochastic']) 82 | if rounding == 'stochastic': 83 | x_floor = x.floor() 84 | return x_floor + torch.bernoulli(x - x_floor) 85 | else: 86 | return x.round() 87 | 88 | class RoundFunction(Function): 89 | @staticmethod 90 | def forward(ctx, input, rounding='deterministic'): 91 | return new_round(input, rounding)
92 | 93 | @staticmethod 94 | def backward(ctx, grad_output): 95 | return grad_output, None 96 | 97 | def round(input, rounding='deterministic'): 98 | return RoundFunction.apply(input, rounding) 99 | 100 | class SignFunction(Function): 101 | @staticmethod 102 | def forward(ctx, input): 103 | return torch.sign(input) 104 | 105 | @staticmethod 106 | def backward(ctx, grad_output): 107 | return grad_output 108 | 109 | def sign(input): 110 | return SignFunction.apply(input) 111 | 112 | class ClampFunction(Function): 113 | @staticmethod 114 | def forward(ctx, input, min, max): 115 | return torch.clamp(input, min, max) 116 | 117 | @staticmethod 118 | def backward(ctx, grad_output): 119 | return grad_output, None, None 120 | 121 | def clamp(input, min, max): 122 | return ClampFunction.apply(input, min, max) 123 | 124 | class ClampAbsFunction(Function): 125 | @staticmethod 126 | def forward(ctx, input, min, max): 127 | assert(min >= 0 and max >=0) 128 | 129 | input[input > max] = max 130 | input[input < -max] = -max 131 | 132 | input[(input > torch.zeros_like(input)) & (input < min)] = min 133 | input[(input < torch.zeros_like(input)) & (input > -min)] = -min 134 | return input 135 | 136 | @staticmethod 137 | def backward(ctx, grad_output): 138 | return grad_output, None, None 139 | 140 | def clampabs(input, min, max): 141 | return ClampAbsFunction.apply(input, min, max) 142 | 143 | class LogFunction(Function): 144 | @staticmethod 145 | def forward(ctx, input): 146 | return torch.log(input) 147 | 148 | @staticmethod 149 | def backward(ctx, grad_output): 150 | return grad_output 151 | 152 | def log(input): 153 | return LogFunction.apply(input) 154 | 155 | class UnsymmetricGradMulFunction(Function): 156 | @staticmethod 157 | def forward(ctx, input1, input2): 158 | ctx.save_for_backward(input1, input2) 159 | return torch.mul(input1, input2) 160 | 161 | @staticmethod 162 | def backward(ctx, grad_output): 163 | input1, input2 = ctx.saved_tensors 164 | return grad_output*input2, grad_output 165 | 166 | def unsym_grad_mul(input1, input2): 167 | return UnsymmetricGradMulFunction.apply(input1, input2) 168 | 169 | 170 | class AbsFunction(Function): 171 | @staticmethod 172 | def forward(ctx, input): 173 | return torch.abs(input) 174 | 175 | @staticmethod 176 | def backward(ctx, grad_output): 177 | return grad_output 178 | 179 | def abs(input): 180 | return AbsFunction.apply(input) -------------------------------------------------------------------------------- /torchshiftadd/utils/test_acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | """Computes the precision@k for the specified values of k""" 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | res = [] 16 | for k in topk: 17 | correct_k = correct[:k].reshape(-1).float().sum(0) 18 | res.append(correct_k.mul_(100.0 / batch_size)) 19 | return res 20 | 21 | def test_acc(model, test_loader): 22 | model.eval() 23 | test_loss = 0 24 | test_acc = 0 25 | test_acc_5 = 0 26 | for data, target in test_loader: 27 | data, target = data.cuda(), target.cuda() 28 | data, target = Variable(data, volatile=True), Variable(target) 29 | output = model(data) 30 | test_loss += F.cross_entropy(output, target, size_average=False).item() 
# sum up batch loss 31 | prec1, prec5 = accuracy(output.data, target.data, topk=(1, 5)) 32 | test_acc += prec1.item() 33 | test_acc_5 += prec5.item() 34 | 35 | test_loss /= len(test_loader.dataset) 36 | print('\nTest set: Average loss: {:.4f}, Prec1: {}/{} ({:.2f}%), Prec5: ({:.2f}%)\n'.format( 37 | test_loss, test_acc, len(test_loader), test_acc / len(test_loader), test_acc_5 / len(test_loader))) 38 | return np.round(test_acc / len(test_loader), 2), np.round(test_acc_5 / len(test_loader), 2) -------------------------------------------------------------------------------- /torchshiftadd/utils/torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils import cpp_extension 5 | from . import decorator, comm 6 | 7 | class LazyExtensionLoader(object): 8 | 9 | def __init__(self, name, sources, extra_cflags=None, extra_cuda_cflags=None, extra_ldflags=None, 10 | extra_include_paths=None, build_directory=None, verbose=False, **kwargs): 11 | self.name = name 12 | self.sources = sources 13 | self.extra_cflags = extra_cflags 14 | self.extra_cuda_cflags = extra_cuda_cflags 15 | self.extra_ldflags = extra_ldflags 16 | self.extra_include_paths = extra_include_paths 17 | worker_name = "%s_%d" % (name, comm.get_rank()) 18 | self.build_directory = build_directory or cpp_extension._get_build_directory(worker_name, verbose) 19 | self.verbose = verbose 20 | self.kwargs = kwargs 21 | 22 | def __getattr__(self, key): 23 | return getattr(self.module, key) 24 | 25 | @decorator.cached_property 26 | def module(self): 27 | return cpp_extension.load(self.name, self.sources, self.extra_cflags, self.extra_cuda_cflags, 28 | self.extra_ldflags, self.extra_include_paths, self.build_directory, 29 | self.verbose, **self.kwargs) 30 | 31 | 32 | def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs): 33 | """ 34 | Load a PyTorch C++ extension just-in-time (JIT). 35 | Automatically decide the compilation flags if not specified. 36 | 37 | This function performs lazy evaluation and is multi-process-safe. 38 | 39 | See `torch.utils.cpp_extension.load`_ for more details. 40 | 41 | .. _torch.utils.cpp_extension.load: 42 | https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load 43 | """ 44 | if extra_cflags is None: 45 | extra_cflags = ["-Ofast"] 46 | if torch.backends.openmp.is_available(): 47 | extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"] 48 | else: 49 | extra_cflags.append("-DAT_PARALLEL_NATIVE") 50 | if extra_cuda_cflags is None: 51 | if torch.cuda.is_available(): 52 | extra_cuda_cflags = ["-O3"] 53 | extra_cflags.append("-DCUDA_OP") 54 | else: 55 | new_sources = [] 56 | for source in sources: 57 | if not cpp_extension._is_cuda_file(source): 58 | new_sources.append(source) 59 | sources = new_sources 60 | 61 | return LazyExtensionLoader(name, sources, extra_cflags, extra_cuda_cflags, **kwargs) 62 | 63 | --------------------------------------------------------------------------------
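
The rounding helpers in torchshiftadd/utils/ste.py are thin torch.autograd.Function wrappers whose backward passes return the incoming gradient unchanged, i.e. straight-through estimators. A minimal sketch of that behaviour using round_fixed_point; the import path is an assumption about how the package is installed:

import torch
from torchshiftadd.utils.ste import round_fixed_point  # assumed import path

w = torch.randn(4, requires_grad=True)
w_q = round_fixed_point(w, quant_bits=8)   # forward: snap to a 2**-8 grid and clamp to the 16-bit integer range
loss = (w_q ** 2).sum()
loss.backward()
# RoundFixedPoint.backward returns grad_output as-is, so w receives the
# gradient computed at the quantized values instead of a zero gradient.
print(w_q, w.grad)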
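
The accuracy helper in torchshiftadd/utils/test_acc.py returns top-k precision as percentages over the batch. A small illustration on random data (the import path is an assumption):

import torch
from torchshiftadd.utils.test_acc import accuracy  # assumed import path

logits = torch.randn(8, 10)              # 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))
prec1, prec5 = accuracy(logits, labels, topk=(1, 5))
# each value is 100 * (number correct within the top-k predictions) / batch size
print(prec1.item(), prec5.item())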
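
load_extension in torchshiftadd/utils/torch.py defers compilation until the extension is first used. A usage sketch; the source file paths below are assumptions based on the repository layout, and the attribute names exposed by the built module depend on what the C++ bindings register:

from torchshiftadd.utils.torch import load_extension  # assumed import path

adder_cuda = load_extension(
    name="adder_cuda",
    sources=[
        "torchshiftadd/layers/extension/adder_cuda.cpp",       # assumed path
        "torchshiftadd/layers/extension/adder_cuda_kernel.cu",  # assumed path
    ],
    verbose=True,
)
# Nothing is compiled at this point: LazyExtensionLoader.module is a cached
# property, so the JIT build (in a per-rank build directory) only runs on the
# first attribute access of adder_cuda.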