├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── figs
│   └── logo_torchshiftadd.png
├── setup.py
├── test
│   ├── layers
│   │   ├── test_adder.py
│   │   └── test_shift.py
│   └── models
│       ├── test_resnet20_adder.py
│       ├── test_resnet20_shift.py
│       └── test_resnet20_shiftadd.py
└── torchshiftadd
    ├── layers
    │   ├── __init__.py
    │   ├── adder.py
    │   ├── attention.py
    │   ├── extension
    │   │   ├── adder_cuda.cpp
    │   │   ├── adder_cuda_kernel.cu
    │   │   └── adder_cuda_kernel_torch_1.4.cu
    │   └── shift.py
    ├── models
    │   ├── __init__.py
    │   ├── resnet20.py
    │   ├── resnet20_adder.py
    │   ├── resnet20_shift.py
    │   └── resnet20_shiftadd.py
    └── utils
        ├── __init__.py
        ├── ckpt_loading.py
        ├── comm.py
        ├── decorator.py
        ├── quantize.py
        ├── ste.py
        └── torch.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | ckpts/
11 | data.cifar*/
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | Coming soon.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![TorchShiftAdd logo](figs/logo_torchshiftadd.png)
2 |
6 | A PyTorch library for developing energy efficient multiplication-less models.
7 |
8 |
9 | [License: Apache-2.0](https://opensource.org/licenses/Apache-2.0)
10 | [Contributing: welcome](https://github.com/GATECH-EIC/torchshiftadd/blob/master/CONTRIBUTING.md)
11 |
12 | # TorchShiftAdd Overview
13 |
14 | Welcome to TorchShiftAdd, your go-to open-source library for crafting energy-efficient multiplication-less models and applications!
15 |
16 | [TorchShiftAdd](https://github.com/GATECH-EIC/torchshiftadd) embodies a pioneering initiative to simplify and expand the realm of multiplication-less networks within the machine learning community. Key features include:
17 |
18 | * Ready-to-use implementation of a wide range of ShiftAdd-based multiplication-less CNNs or Transformers.
19 | * CUDA kernels and TVM compilation support for seamless GPU deployment.
20 | * Profiling tools to furnish FLOPs, energy, and latency breakdown data for in-depth analysis and optimization.
21 | * Hardware accelerator simulators to estimate energy savings and latency improvements on ASICs or FPGAs.
22 | * Flexible support for developing both algorithmic and hardware accelerator designs tailored for multiplication-less networks.
23 |
24 |
25 |
26 | ## List of Implemented Papers
27 | * **ShiftAdd-based Convolutional Neural Networks**
28 | + [[NeurIPS'20] ShiftAddNet: A Hardware-Inspired Deep Network](https://arxiv.org/abs/2010.12785)
29 | + [[CVPR'20 Oral] AdderNet: Do We Really Need Multiplications in Deep Learning?](https://arxiv.org/abs/1912.13200)
30 | + [[CVPR'21 Workshop] DeepShift: Towards Multiplication-Less Neural Networks](https://arxiv.org/abs/1905.13298)
31 | * **ShiftAdd-based Transformers**
32 | + [[NeurIPS'23] ShiftAddViT: Mixture of Multiplication Primitives Towards Efficient Vision Transformer](https://arxiv.org/abs/2306.06446)
33 | + [[NeurIPS'24] ShiftAddLLM: Accelerating Pretrained LLMs via Post-Training Multiplication-Less Reparameterization](https://arxiv.org/abs/2406.05981)
34 | * **Hardware Accelerators for ShiftAdd-based Multiplication-less Networks**
35 | + [[ICCAD'22] NASA: Neural Architecture Search and Acceleration for Hardware Inspired Hybrid Networks](https://arxiv.org/abs/2210.13361)
36 | + [[IEEE TCAS-I] NASA+: Neural Architecture Search and Acceleration for Multiplication-Reduced Hybrid Networks](https://ieeexplore.ieee.org/document/10078392)
37 |
38 | # Installation
39 |
40 | ````bash
41 | pip install -e .
42 | ````
43 |
44 | # Quick Start
45 |
46 | Currently, the codebase supports ShiftAdd-based CNNs. To try them out, run our test files:
47 |
48 | ````bash
49 | python test/models/test_resnet20_adder.py
50 | python test/models/test_resnet20_shift.py
51 | python test/models/test_resnet20_shiftadd.py
52 | ````
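
The layers can also be used directly in custom modules. Below is a minimal sketch based on the constructor arguments exercised in `test/layers/test_adder.py`; whether CPU tensors suffice is an assumption, since the bundled `adder_cuda` extension targets CUDA, so moving the layer and inputs to the GPU may be required in practice.

````python
import torch
from torchshiftadd import layers

# Adder2D: multiplication-less "convolution" that accumulates -|input - weight|
# (AdderNet-style). Arguments mirror test/layers/test_adder.py.
adder = layers.Adder2D(
    input_channel=3,
    output_channel=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
)

x = torch.rand(1, 3, 32, 32)
out = adder(x)  # expected shape: (1, 64, 32, 32)
print(out.shape)
````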
53 |
54 | # Upcoming Features
55 |
56 | We will continuously develop this toolbox:
57 |
58 | - [x] ShiftAdd-based Convolutional Neural Networks
59 | - [ ] ShiftAdd-based Transformers
60 | - [ ] Hardware Accelerators for Energy & Latency Estimation
61 |
62 | # Contributing
63 |
64 | TorchShiftAdd is released under [Apache-2.0 License](LICENSE). Everyone is welcome to contribute to the development of TorchShiftAdd. Please refer to [contributing guidelines](CONTRIBUTING.md) for more details.
65 |
66 | # Acknowledgement
67 |
68 | We thank all co-authors of ShiftAddNet, ShiftAddNAS, ShiftAddViT, and ShiftAddLLM.
--------------------------------------------------------------------------------
/figs/logo_torchshiftadd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/torchshiftadd/f48be2983b0a0f3a299eca449e88c6fb8e5264fb/figs/logo_torchshiftadd.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="torchshiftadd",
8 | version="0.0.1",
9 | author="TorchShiftAdd Team",
10 | author_email="haoran.you@gatech.edu",
11 | description="A PyTorch library for developing energy efficient multiplication-less models",
12 | license="Apache License 2.0",
13 | long_description=long_description,
14 | long_description_content_type="text/markdown",
15 | url="https://github.com/GATECH-EIC/torchshiftadd",
16 | packages=setuptools.find_packages(),
17 | package_data={
18 | "torchshiftadd": [
19 | "layers/extension/adder_cuda.cpp",
20 | "layers/extension/adder_cuda_kernel.cu",
21 | ]
22 | },
23 | install_requires=[
24 | "torch>=1.7.0",
25 | "torchvision",
26 | "numpy>=1.19.0",
27 | "scipy>=1.5.0",
28 | "scikit-learn>=0.23.0",
29 | "matplotlib>=3.2.0",
30 | "tqdm>=4.46.0",
31 | "ninja",
32 | ],
33 | python_requires=">=3.6,<3.13",
34 | classifiers=[
35 | "Programming Language :: Python :: 3",
36 | "Operating System :: OS Independent",
37 | ],
38 | )
--------------------------------------------------------------------------------
/test/layers/test_adder.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 |
4 | from torchshiftadd import layers
5 |
6 | class Adder2DTest(unittest.TestCase):
7 |
8 | def setup(self):
9 | self.input = torch.rand(1, 3, 32, 32)
10 | self.weight = torch.rand(64, 3, 3, 3)
11 | self.bias = torch.rand(64)
12 | self.stride = 1
13 | self.padding = 1
14 | self.groups = 1
15 | self.eta = 1.0
16 |
17 | def test_adder2d(self):
18 | self.setup()
19 | adder = layers.Adder2D(
20 | input_channel=3,
21 | output_channel=64,
22 | kernel_size=3,
23 | stride=self.stride,
24 | padding=self.padding,
25 | groups=self.groups,
26 | bias=True,
27 | eta=self.eta,
28 | )
29 | output = adder(self.input)
30 | self.assertEqual(output.shape, (1, 64, 32, 32))
31 |
32 | if __name__ == "__main__":
33 | unittest.main()
--------------------------------------------------------------------------------
/test/layers/test_shift.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 |
4 | from torchshiftadd import layers
5 |
6 | class LinearShiftTest(unittest.TestCase):
7 |
8 | def setup(self):
9 | self.input = torch.rand(32, 32)
10 |
11 | def test_linear_shift(self):
12 | self.setup()
13 | shift = layers.LinearShift(
14 | in_features=32,
15 | out_features=64,
16 | bias=True,
17 | )
18 | output = shift(self.input)
19 | self.assertEqual(output.shape, (32, 64))
20 |
21 | class Conv2dShiftTest(unittest.TestCase):
22 |
23 | def setup(self):
24 | self.input = torch.rand(1, 3, 32, 32)
25 | self.weight = torch.rand(64, 3, 3, 3)
26 | self.bias = torch.rand(64)
27 | self.stride = 1
28 | self.padding = 1
29 | self.groups = 1
30 |
31 | def test_conv2d_shift(self):
32 | self.setup()
33 | shift = layers.Conv2dShift(
34 | in_channels=3,
35 | out_channels=64,
36 | kernel_size=3,
37 | stride=self.stride,
38 | padding=self.padding,
39 | groups=self.groups,
40 | bias=True,
41 | weight_bits=4,
42 | input_bits=16,
43 | )
44 | output = shift(self.input)
45 | self.assertEqual(output.shape, (1, 64, 32, 32))
46 |
47 | if __name__ == "__main__":
48 | unittest.main()
--------------------------------------------------------------------------------
/test/models/test_resnet20_adder.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 | from huggingface_hub import hf_hub_download
6 | from torchshiftadd import layers, models
7 | from torchshiftadd.utils import load_add_state_dict, test_acc
8 |
9 | class ResNetAdderShapeTest(unittest.TestCase):
10 |
11 | def setup(self):
12 | self.input = torch.rand(2, 3, 32, 32)
13 | self.model = models.resnet20_adder()
14 |
15 | def test_resnet_adder(self):
16 | self.setup()
17 | output = self.model(self.input)
18 | self.assertEqual(output.shape, (2, 10))
19 |
20 | class ResNetAdderAccTest(unittest.TestCase):
21 |
22 | def setup(self):
23 | self.model = models.resnet20_adder().cuda()
24 | self.test_dataloader = torch.utils.data.DataLoader(
25 | datasets.CIFAR10(
26 | './data.cifar10',
27 | train=False,
28 | download=True,
29 | transform=transforms.Compose([
30 | transforms.ToTensor(),
31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
32 | ])
33 | ),
34 | batch_size=64,
35 | shuffle=True,
36 | )
37 |
38 | def test_resnet_adder(self):
39 | self.setup()
40 |
41 | ckpt_path = hf_hub_download(
42 | repo_id="hryou1998/pretrained-ckpts",
43 | filename="resnet20-adder-cifar10-FP32.pth.tar",
44 | )
45 |
46 | checkpoint = torch.load(ckpt_path, map_location='cpu')
47 | self.model.load_state_dict(load_add_state_dict(checkpoint['state_dict']))
48 |
49 | top_1, top_5 = test_acc(self.model, self.test_dataloader)
50 |
51 | print("Top-1 Acc: {:.2f}%".format(top_1))
52 | print("Top-5 Acc: {:.2f}%".format(top_5))
53 |
54 | if __name__ == "__main__":
55 | unittest.main()
--------------------------------------------------------------------------------
/test/models/test_resnet20_shift.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 | from huggingface_hub import hf_hub_download
6 | from torchshiftadd import layers, models
7 | from torchshiftadd.utils import test_acc
8 |
9 | class ResNetShiftShapeTest(unittest.TestCase):
10 |
11 | def setup(self):
12 | self.input = torch.rand(2, 3, 32, 32)
13 | self.model = models.resnet20()
14 | models.convert_to_shift(self.model)
15 |
16 | def test_resnet_shift(self):
17 | self.setup()
18 | output = self.model(self.input)
19 | self.assertEqual(output.shape, (2, 10))
20 |
21 | class ResNetShiftAccTest(unittest.TestCase):
22 |
23 | def setup(self):
24 | self.model = models.resnet20()
25 | models.convert_to_shift(self.model)
26 | self.model.cuda()
27 | self.test_dataloader = torch.utils.data.DataLoader(
28 | datasets.CIFAR10(
29 | './data.cifar10',
30 | train=False,
31 | download=True,
32 | transform=transforms.Compose([
33 | transforms.ToTensor(),
34 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
35 | ])
36 | ),
37 | batch_size=64,
38 | shuffle=True,
39 | )
40 |
41 | def test_resnet_shift(self):
42 | self.setup()
43 |
44 | ckpt_path = hf_hub_download(
45 | repo_id="hryou1998/pretrained-ckpts",
46 | filename="resnet20-shift-cifar10.pth.tar",
47 | )
48 |
49 | checkpoint = torch.load(ckpt_path, map_location='cpu')
50 | self.model.load_state_dict(checkpoint['state_dict'])
51 | models.convert_to_shift(self.model)
52 |
53 | top_1, top_5 = test_acc(self.model, self.test_dataloader)
54 |
55 | print("Top-1 Acc: {:.2f}%".format(top_1))
56 | print("Top-5 Acc: {:.2f}%".format(top_5))
57 |
58 | if __name__ == "__main__":
59 | unittest.main()
--------------------------------------------------------------------------------
/test/models/test_resnet20_shiftadd.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 | from huggingface_hub import hf_hub_download
6 | from torchshiftadd import layers, models
7 | from torchshiftadd.utils import load_shiftadd_state_dict, test_acc
8 |
9 | class ResNetShiftAddShapeTest(unittest.TestCase):
10 |
11 | def setup(self):
12 | self.input = torch.rand(2, 3, 32, 32)
13 | self.model = models.resnet20_shiftadd()
14 |
15 | def test_resnet_shiftadd(self):
16 | self.setup()
17 | output = self.model(self.input)
18 | self.assertEqual(output.shape, (2, 10))
19 |
20 | class ResNetShiftAddAccTest(unittest.TestCase):
21 |
22 | def setup(self):
23 | self.model = models.resnet20_shiftadd().cuda()
24 | self.test_dataloader = torch.utils.data.DataLoader(
25 | datasets.CIFAR10(
26 | './data.cifar10',
27 | train=False,
28 | download=True,
29 | transform=transforms.Compose([
30 | transforms.ToTensor(),
31 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
32 | ])
33 | ),
34 | batch_size=64,
35 | shuffle=True,
36 | )
37 |
38 | def test_resnet_shiftadd(self):
39 | self.setup()
40 |
41 | ckpt_path = hf_hub_download(
42 | repo_id="hryou1998/pretrained-ckpts",
43 | filename="resnet20-shiftadd-cifar10-FIX16.pth.tar",
44 | )
45 |
46 | checkpoint = torch.load(ckpt_path, map_location='cpu')
47 | self.model.load_state_dict(load_shiftadd_state_dict(checkpoint['state_dict']), strict=False)
48 |
49 | top_1, top_5 = test_acc(self.model, self.test_dataloader)
50 |
51 | print("Top-1 Acc: {:.2f}%".format(top_1))
52 | print("Top-5 Acc: {:.2f}%".format(top_5))
53 |
54 | if __name__ == "__main__":
55 | unittest.main()
--------------------------------------------------------------------------------
/torchshiftadd/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .adder import Adder2D
2 | from .shift import LinearShift, Conv2dShift
--------------------------------------------------------------------------------
/torchshiftadd/layers/adder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 |
7 | from torchshiftadd import utils
8 |
9 | path = os.path.join(os.path.dirname(__file__), "extension")
10 | adder_cuda = utils.load_extension(
11 | "adder_cuda", [
12 | os.path.join(path, "adder_cuda.cpp"),
13 | os.path.join(path, "adder_cuda_kernel.cu"),
14 | ]
15 | )
16 |
17 | class Adder2D(nn.Module):
18 |
19 | def __init__(self,
20 | input_channel,
21 | output_channel,
22 | kernel_size,
23 | stride = 1,
24 | padding = 0,
25 | groups = 1,
26 | bias = False,
27 | eta = 0.2,
28 | ):
29 | super(Adder2D, self).__init__()
30 | self.stride = stride
31 | self.padding = padding
32 | self.groups = groups
33 | self.input_channel = input_channel
34 | self.output_channel = output_channel
35 | self.kernel_size = kernel_size
36 | self.bias = bias
37 | self.eta = eta
38 |
39 | self.adder = torch.nn.Parameter(
40 | nn.init.normal_(
41 | torch.randn(output_channel, input_channel // groups, kernel_size, kernel_size)
42 | )
43 | )
44 |
45 | if self.bias:
46 | self.bias = torch.nn.Parameter(
47 | nn.init.uniform_(torch.zeros(output_channel))
48 | )
49 | else:
50 | self.bias = None
51 |
52 | def forward(self, input, ratio_out=1, ratio_in=1, ratio_g=1, kernel=None):
53 |
54 | sample_weight = self.adder[:(self.output_channel//ratio_out),:(self.input_channel//ratio_in),:,:]
55 | if kernel is not None:
56 | start, end = sub_filter_start_end(5, kernel)
57 | sample_weight = sample_weight[:,:, start:end, start:end]
58 | padding = kernel//2
59 | else:
60 | padding = self.padding
61 |
62 | output = Adder2DFunction.apply(
63 | input,
64 | sample_weight,
65 | self.kernel_size,
66 | self.stride,
67 | padding,
68 | (self.groups//ratio_g),
69 | self.eta,
70 | )
71 | if self.bias is not None:
72 | output += self.bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
73 |
74 | return output
75 |
76 |
77 | class Adder2DFunction(torch.autograd.Function):
78 | @staticmethod
79 | def forward(ctx, input, weight, kernel_size, stride, padding, groups, eta):
80 | ctx.save_for_backward(input, weight)
81 | ctx.kernel_size = kernel_size
82 | ctx.stride = stride
83 | ctx.padding = padding
84 | ctx.groups = groups
85 | ctx.eta = eta
86 | ctx.quantize = False
87 |
88 | output = input.new_zeros(
89 | get_conv2d_output_shape(input, weight, stride, padding)
90 | )
91 |
92 | adder_cuda.forward(
93 | input,
94 | weight,
95 | output,
96 | kernel_size, kernel_size,
97 | stride, stride,
98 | padding, padding,
99 | groups, groups
100 | )
101 |
102 | return output
103 |
104 | @staticmethod
105 | def backward(ctx, grad_output):
106 | input, weight = ctx.saved_tensors
107 | grad_input = grad_weight = None
108 | eta, kernel_size, stride, padding, groups = (
109 | ctx.eta, ctx.kernel_size, ctx.stride, ctx.padding, ctx.groups
110 | )
111 |
112 | # input
113 | if ctx.needs_input_grad[0]:
114 | grad_input = torch.zeros_like(input)
115 | adder_cuda.backward_input(
116 | grad_output,
117 | input,
118 | weight,
119 | grad_input,
120 | kernel_size, kernel_size,
121 | stride, stride,
122 | padding, padding,
123 | groups, groups
124 | )
125 |
126 | # weight
127 | if ctx.needs_input_grad[1]:
128 | grad_weight = torch.zeros_like(weight)
129 | adder_cuda.backward_weight(
130 | grad_output,
131 | input,
132 | weight,
133 | grad_weight,
134 | kernel_size, kernel_size,
135 | stride, stride,
136 | padding, padding,
137 | groups, groups)
138 | grad_weight = eta * np.sqrt(grad_weight.numel()) / torch.norm(grad_weight).clamp(min=1e-12) * grad_weight  # adaptive gradient scaling (AdderNet): rescale grad_weight to norm eta * sqrt(numel)
139 |
140 | return grad_input, grad_weight, None, None, None, None, None
141 |
142 |
143 | def get_conv2d_output_shape(input, weight, stride, padding):
144 | n_filters, d_filter, h_filter, w_filter = weight.size()
145 | n_x, d_x, h_x, w_x = input.size()
146 |
147 | h_out = (h_x - h_filter + 2 * padding) // stride + 1
148 | w_out = (w_x - w_filter + 2 * padding) // stride + 1
149 |
150 | return (n_x, n_filters, h_out, w_out)
151 |
152 |
153 | def sub_filter_start_end(kernel_size, sub_kernel_size):
154 | center = kernel_size // 2
155 | dev = sub_kernel_size // 2
156 | start, end = center - dev, center + dev + 1
157 | assert end - start == sub_kernel_size
158 | return start, end
--------------------------------------------------------------------------------
/torchshiftadd/layers/attention.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/torchshiftadd/f48be2983b0a0f3a299eca449e88c6fb8e5264fb/torchshiftadd/layers/attention.py
--------------------------------------------------------------------------------
/torchshiftadd/layers/extension/adder_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 |
4 | int adder_cuda_forward(
5 | const at::Tensor &input,
6 | const at::Tensor &weight,
7 | // const at::Tensor &bias,
8 | at::Tensor &output,
9 | int KW, int KH,
10 | int SW, int SH,
11 | int PW, int PH,
12 | int GW, int GH
13 | );
14 |
15 | int adder_cuda_backward_grad_in(
16 | at::Tensor &grad_out,
17 | at::Tensor &input,
18 | at::Tensor &weight,
19 | at::Tensor &grad_in,
20 | int KW, int KH,
21 | int SW, int SH,
22 | int PW, int PH,
23 | int GW, int GH
24 | );
25 |
26 | int adder_cuda_backward_grad_weight(
27 | at::Tensor &grad_out,
28 | at::Tensor &input,
29 | at::Tensor &weight,
30 | at::Tensor &grad_weight,
31 | int KW, int KH,
32 | int SW, int SH,
33 | int PW, int PH,
34 | int GW, int GH
35 | );
36 |
37 | #define CHECK_CUDA(x) AT_ASSERT((x).is_cuda(), #x " must be a CUDA tensor")
38 | #define CHECK_CONTIGUOUS(x) AT_ASSERT((x).is_contiguous(), #x " must be contiguous")
39 | #define CHECK_INPUT(x) \
40 | CHECK_CUDA((x)); \
41 | CHECK_CONTIGUOUS((x))
42 |
43 | int adder_forward(
44 | const at::Tensor &input,
45 | const at::Tensor &weight,
46 | // const at::Tensor &bias,
47 | at::Tensor &output,
48 | int KW, int KH,
49 | int SW, int SH,
50 | int PW, int PH,
51 | int GW, int GH
52 | )
53 | {
54 | // TODO: add input checks
55 | return adder_cuda_forward(
56 | input,
57 | weight,
58 | // bias,
59 | output,
60 | KW, KH,
61 | SW, SH,
62 | PW, PH,
63 | GW, GH
64 | );
65 | }
66 |
67 | int adder_backward_input(
68 | at::Tensor &grad_out,
69 | at::Tensor &input,
70 | at::Tensor &weight,
71 | at::Tensor &grad_in,
72 | int KW, int KH,
73 | int SW, int SH,
74 | int PW, int PH,
75 | int GW, int GH
76 | )
77 | {
78 | // TODO: add input checks
79 | return adder_cuda_backward_grad_in(
80 | grad_out,
81 | input,
82 | weight,
83 | grad_in,
84 | KW, KH,
85 | SW, SH,
86 | PW, PH,
87 | GW, GH
88 | );
89 | }
90 |
91 | int adder_backward_weight(
92 | at::Tensor &grad_out,
93 | at::Tensor &input,
94 | at::Tensor &weight,
95 | at::Tensor &grad_weight,
96 | int KW, int KH,
97 | int SW, int SH,
98 | int PW, int PH,
99 | int GW, int GH
100 | )
101 | {
102 | // TODO: add input checks
103 | return adder_cuda_backward_grad_weight(
104 | grad_out,
105 | input,
106 | weight,
107 | grad_weight,
108 | KW, KH,
109 | SW, SH,
110 | PW, PH,
111 | GW, GH
112 | );
113 | }
114 |
115 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
116 | {
117 | m.def("forward", &adder_forward, "adder forward (CUDA)");
118 | m.def("backward_input", &adder_backward_input, "adder backward input (CUDA)");
119 | m.def("backward_weight", &adder_backward_weight, "adder backward weight (CUDA)");
120 | }
--------------------------------------------------------------------------------
/torchshiftadd/layers/extension/adder_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/ATen.h>
3 | #include <cuda.h>
4 | #include <cuda_runtime.h>
5 | #include <vector>
6 |
7 | #include <c10/cuda/CUDAException.h>
8 | #include <algorithm>
9 | #define MAX_BLOCKS 256
10 | #define NUM_THREADS 256
11 | #define MAX(a, b) ((a) > (b)) ? (a) : (b)
12 | #define MIN(a, b) ((a) < (b)) ? (a) : (b)
13 | #define HARDTANH(x) ((x) < (-1.0)) ? (-1.0) : (((x) <= (1.0)) ? (x) : (1.0))
14 | const int WARP_SIZE = 32;
15 | const int MAX_BLOCK_SIZE = 256;
16 |
17 | template <typename T>
18 | struct SharedMem {
19 | __device__ T *getPointer() {
20 | __shared__ T smem[MAX_BLOCK_SIZE];
21 | return smem;
22 | }
23 | };
24 |
25 | static int getGradParamsNumThreads(int batchSize) {
26 | return std::min(batchSize * WARP_SIZE, MAX_BLOCK_SIZE);
27 | }
28 |
29 | int get_blocks(int n) {
30 | return MIN(MAX_BLOCKS, (n - NUM_THREADS + 1) / NUM_THREADS) + 1;
31 | }
32 |
33 | template <typename scalar_t>
34 | __global__ void adder_forward_kernel(
35 | const scalar_t* __restrict__ input,
36 | const scalar_t* __restrict__ weight,
37 | scalar_t* __restrict__ output,
38 | const int num_elem,
39 | const int out_channels,
40 | const int in_channels,
41 | const int IW, const int IH,
42 | const int OW, const int OH,
43 | const int KW, const int KH,
44 | const int SW, const int SH,
45 | const int PW, const int PH,
46 | const int GW, const int GH)
47 | {
48 | // #TODO:
49 | if (GW==1)
50 | {
51 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
52 | {
53 | const int n = index / OW / OH / out_channels;
54 | const int m = index / OW / OH % out_channels;
55 | const int h = index / OW % OH;
56 | const int w = index % OW;
57 |
58 | const scalar_t *p_weight = weight + m * in_channels * KH * KW; //the start position of the kernel(corresponding to the output)
59 | // scalar_t value = bias[m];
60 | scalar_t value = 0;
61 | // #TODO:
62 | // #pragma unroll
63 | for (int cc = 0; cc < in_channels; cc++)
64 | {
65 | // #pragma unroll
66 | const int image_offset0 = (n * in_channels + cc) * IH * IW;
67 | for (int kh = 0; kh < KH; kh++)
68 | {
69 | // #pragma unroll
70 | for (int kw = 0; kw < KW; kw++)
71 | {
72 | const int ih = h * SH - PH + kh;
73 | const int iw = w * SW - PW + kw;
74 |
75 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
76 | if (boundary_condition)
77 | {
78 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
79 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight));
80 | }
81 | else // padded area
82 | {
83 | value -= abs(*p_weight);
84 | }
85 | p_weight++;
86 | }
87 | }
88 | }
89 | output[index] = value;
90 | }
91 | }
92 | if (GW==2)
93 | {
94 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
95 | // total size of output (batch_size * out_channels * W-out * H_out)
96 | {
97 | // TODO
98 | const int n = index / OW / OH / out_channels; // batch size of n
99 | const int m = index / OW / OH % out_channels; // relative output channel size
100 | const int h = index / OW % OH; // relative position of H
101 | const int w = index % OW; //relative position of W
102 |
103 | const scalar_t *p_weight = weight + m * in_channels/2 * KH * KW; //the start position of the kernel(corresponding to the output)
104 |
105 | // scalar_t value = bias[m];
106 | scalar_t value = 0;
107 | // #TODO:
108 | // #pragma unroll
109 | if (m < out_channels/2)
110 | {
111 | for (int cc = 0; cc < in_channels/2; cc++)
112 | {
113 | // #pragma unroll
114 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute)
115 | for (int kh = 0; kh < KH; kh++)
116 | {
117 | // #pragma unroll
118 | for (int kw = 0; kw < KW; kw++)
119 | {
120 | const int ih = h * SH - PH + kh; // *stride-padding of H
121 | const int iw = w * SW - PW + kw; // *stride-padding of W
122 |
123 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
124 | if (boundary_condition)
125 | {
126 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
127 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation
128 | }
129 | else // padded area
130 | {
131 | value -= abs(*p_weight);
132 | }
133 | p_weight++;
134 | }
135 | }
136 | }
137 | }
138 | else
139 | {
140 | for (int cc = in_channels/2; cc < in_channels; cc++)
141 | {
142 | // #pragma unroll
143 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute)
144 | for (int kh = 0; kh < KH; kh++)
145 | {
146 | // #pragma unroll
147 | for (int kw = 0; kw < KW; kw++)
148 | {
149 | const int ih = h * SH - PH + kh; // *stride-padding of H
150 | const int iw = w * SW - PW + kw; // *stride-padding of W
151 |
152 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
153 | if (boundary_condition)
154 | {
155 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
156 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation
157 | }
158 | else // padded area
159 | {
160 | value -= abs(*p_weight);
161 | }
162 | p_weight++;
163 | }
164 | }
165 | }
166 | }
167 | output[index] = value;
168 | }
169 | }
170 | if (GW==in_channels) // depthwise case
171 | {
172 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
173 | {
174 | const int n = index / OW / OH / out_channels;
175 | const int m = index / OW / OH % out_channels;
176 | const int h = index / OW % OH;
177 | const int w = index % OW;
178 |
179 | const scalar_t *p_weight = weight + m * 1 * KH * KW;
180 | // scalar_t value = bias[m];
181 | scalar_t value = 0;
182 | // #TODO:
183 | // #pragma unroll
184 | // #pragma unroll
185 | const int image_offset0 = (n * in_channels + m) * IH * IW;
186 | for (int kh = 0; kh < KH; kh++)
187 | {
188 | // #pragma unroll
189 | for (int kw = 0; kw < KW; kw++)
190 | {
191 | const int ih = h * SH - PH + kh;
192 | const int iw = w * SW - PW + kw;
193 |
194 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
195 | if (boundary_condition)
196 | {
197 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
198 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight));
199 | }
200 | else // padded area
201 | {
202 | value -= abs(*p_weight);
203 | }
204 | p_weight++;
205 | }
206 | }
207 | output[index] = value;
208 | }
209 | }
210 |
211 | }
212 |
213 | template <typename scalar_t>
214 | __global__ void adder_backward_grad_in_kernel(
215 | scalar_t *grad_out,
216 | scalar_t *input,
217 | scalar_t *weight,
218 | scalar_t *grad_in,
219 | const int num_elem,
220 | const int out_channels,
221 | const int in_channels,
222 | const int IW, const int IH,
223 | const int OW, const int OH,
224 | const int KW, const int KH,
225 | const int SW, const int SH,
226 | const int PW, const int PH,
227 | const int GW, const int GH)
228 | { if (GW==1)
229 | {
230 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
231 | {
232 | const int n = index / IW / IH / in_channels;
233 | const int c = index / IW / IH % in_channels;
234 | const int h = index / IW % IH;
235 | const int w = index % IW;
236 |
237 | scalar_t value = 0;
238 | for (int mm = 0; mm < out_channels; mm++)
239 | {
240 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW;
241 | scalar_t *p_weight = weight + (mm * in_channels + c) * KH * KW;
242 | for (int kh = 0; kh < KH; kh++)
243 | {
244 | for (int kw = 0; kw < KW; kw++)
245 | {
246 | int oh = h + PH - kh;
247 | int ow = w + PW - kw;
248 |
249 | if ((oh % SH == 0) && (ow % SW == 0))
250 | {
251 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
252 | if (boundary_condition)
253 | {
254 | oh = oh / SH;
255 | ow = ow / SW;
256 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
257 | scalar_t ht = HARDTANH(*p_weight - input[index]);
258 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
259 | }
260 | }
261 | p_weight++;
262 | }
263 | }
264 | }
265 | grad_in[index] = value;
266 | }
267 | }
268 | if (GW==2)
269 | {
270 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
271 | {
272 | const int n = index / IW / IH / in_channels;
273 | const int c = index / IW / IH % in_channels;
274 | const int h = index / IW % IH;
275 | const int w = index % IW;
276 |
277 | scalar_t value = 0;
278 | if (c < in_channels/2)
279 | {
280 | for (int mm = 0; mm < out_channels/2; mm++)
281 | {
282 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW;
283 | scalar_t *p_weight = weight + (mm * in_channels/2 + c) * KH * KW;
284 | for (int kh = 0; kh < KH; kh++)
285 | {
286 | for (int kw = 0; kw < KW; kw++)
287 | {
288 | int oh = h + PH - kh;
289 | int ow = w + PW - kw;
290 |
291 | if ((oh % SH == 0) && (ow % SW == 0))
292 | {
293 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
294 | if (boundary_condition)
295 | {
296 | oh = oh / SH;
297 | ow = ow / SW;
298 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
299 | scalar_t ht = HARDTANH(*p_weight - input[index]);
300 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
301 | }
302 | }
303 | p_weight++;
304 | }
305 | }
306 | }
307 | }
308 | else
309 | {
310 | for (int mm = out_channels/2; mm < out_channels; mm++)
311 | {
312 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW;
313 | scalar_t *p_weight = weight + (mm * in_channels/2 + c - out_channels/2) * KH * KW;
314 | for (int kh = 0; kh < KH; kh++)
315 | {
316 | for (int kw = 0; kw < KW; kw++)
317 | {
318 | int oh = h + PH - kh;
319 | int ow = w + PW - kw;
320 |
321 | if ((oh % SH == 0) && (ow % SW == 0))
322 | {
323 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
324 | if (boundary_condition)
325 | {
326 | oh = oh / SH;
327 | ow = ow / SW;
328 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
329 | scalar_t ht = HARDTANH(*p_weight - input[index]);
330 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
331 | }
332 | }
333 | p_weight++;
334 | }
335 | }
336 | }
337 | }
338 | grad_in[index] = value;
339 | }
340 | }
341 | if (GW==in_channels)
342 | {
343 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
344 | {
345 | const int n = index / IW / IH / in_channels;
346 | const int c = index / IW / IH % in_channels;
347 | const int h = index / IW % IH;
348 | const int w = index % IW;
349 |
350 | scalar_t value = 0;
351 |
352 | const int grad_out_offset0 = (n * out_channels + c) * OH * OW;
353 | scalar_t *p_weight = weight + c * KH * KW;
354 | for (int kh = 0; kh < KH; kh++)
355 | {
356 | for (int kw = 0; kw < KW; kw++)
357 | {
358 | int oh = h + PH - kh;
359 | int ow = w + PW - kw;
360 |
361 | if ((oh % SH == 0) && (ow % SW == 0))
362 | {
363 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
364 | if (boundary_condition)
365 | {
366 | oh = oh / SH;
367 | ow = ow / SW;
368 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
369 | scalar_t ht = HARDTANH(*p_weight - input[index]);
370 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
371 | }
372 | }
373 | p_weight++;
374 | }
375 | }
376 | grad_in[index] = value;
377 | }
378 | }
379 | }
380 |
381 | template <typename scalar_t>
382 | __global__ void adder_backward_grad_weight_kernel(
383 | scalar_t *grad_out,
384 | scalar_t *input,
385 | scalar_t *weight,
386 | scalar_t *grad_weight,
387 | const int batch_size,
388 | const int out_channels,
389 | const int in_channels,
390 | const int IW, const int IH,
391 | const int OW, const int OH,
392 | const int KW, const int KH,
393 | const int SW, const int SH,
394 | const int PW, const int PH,
395 | const int GW, const int GH)
396 | {
397 | int bidx = blockIdx.x;
398 | int kW = bidx % KW;
399 | int kH = bidx / KW % KH;
400 | int ch = bidx / KW / KH % in_channels;
401 | int mh = bidx / KW / KH / in_channels;
402 |
403 | if (GW == 2) {
404 | ch = bidx / KW / KH % (in_channels / 2);
405 | mh = bidx / KW / KH / (in_channels / 2);
406 | if (mh >= out_channels / 2) {
407 | ch = ch + in_channels / 2;
408 | }
409 | }
410 | if (GW == in_channels) {
411 | ch = bidx / KW / KH;
412 | mh = bidx / KW / KH;
413 | }
414 |
415 | scalar_t grad = 0;
416 | const int laneId = threadIdx.x % WARP_SIZE;
417 | const int batch = threadIdx.x / WARP_SIZE;
418 | const int nwarps = blockDim.x / WARP_SIZE;
419 | const int imageElements = OW * OH;
420 |
421 | for (int batchIdx = batch; batchIdx < batch_size; batchIdx += nwarps) {
422 | for (int idx = laneId; idx < imageElements; idx += WARP_SIZE) {
423 | int go_w_offset = idx % OW;
424 | int go_h_offset = idx / OW;
425 |
426 | int i_w_offset = go_w_offset * SW + kW - PW;
427 | int i_h_offset = go_h_offset * SH + kH - PH;
428 |
429 | int outputOffset = ((batchIdx * out_channels + mh) * OH) * OW + idx;
430 | if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < IW && i_h_offset < IH) {
431 | int inputOffset = ((batchIdx * in_channels + ch) * IH + i_h_offset) * IW + i_w_offset;
432 | grad += (input[inputOffset] - weight[bidx]) * grad_out[outputOffset];
433 | } else {
434 | grad += -weight[bidx] * grad_out[outputOffset];
435 | }
436 | }
437 | }
438 |
439 | __shared__ scalar_t shared[NUM_THREADS];
440 | shared[threadIdx.x] = grad;
441 | __syncthreads();
442 |
443 | for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
444 | if (threadIdx.x < stride) {
445 | shared[threadIdx.x] += shared[threadIdx.x + stride];
446 | }
447 | __syncthreads();
448 | }
449 |
450 | if (threadIdx.x == 0) {
451 | scalar_t tval = shared[0];
452 | if (GW == 1) {
453 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * in_channels * mh);
454 | grad_weight[weightOffset] = tval;
455 | }
456 | if (GW == 2) {
457 | if (mh < out_channels / 2) {
458 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * (in_channels / 2) * mh);
459 | grad_weight[weightOffset] = tval;
460 | } else {
461 | int weightOffset = kW + (KW * kH) + (KW * KH * (ch - in_channels / 2)) + (KW * KH * (in_channels / 2) * mh);
462 | grad_weight[weightOffset] = tval;
463 | }
464 | }
465 | if (GW == in_channels) {
466 | int weightOffset = kW + (KW * kH) + (KW * KH * 0) + (KW * KH * 1 * mh);
467 | grad_weight[weightOffset] = tval;
468 | }
469 | }
470 | }
471 |
472 | ////////////////////////////////////////////////////////////////////////
473 | ////////////////////////////END OF KERNEL///////////////////////////////
474 | ////////////////////////////////////////////////////////////////////////
475 |
476 | int adder_cuda_forward(
477 | const at::Tensor &input,
478 | const at::Tensor &weight,
479 | // const at::Tensor &bias,
480 | at::Tensor &output,
481 | int KW, int KH,
482 | int SW, int SH,
483 | int PW, int PH,
484 | int GW, int GH
485 | )
486 | {
487 | const int batch_size = output.size(0);
488 | const int in_channels = input.size(1);
489 | const int out_channels = output.size(1);
490 | const int IW = input.size(3);
491 | const int IH = input.size(2);
492 | const int OW = output.size(3);
493 | const int OH = output.size(2);
494 | const int num_elem = batch_size * out_channels * OH * OW;
495 | const int num_blocks = get_blocks(num_elem);
496 |
497 | AT_DISPATCH_FLOATING_TYPES(output.scalar_type(), "adder_cuda_forward", ([&] {
498 | adder_forward_kernel<scalar_t><<<num_blocks, NUM_THREADS>>>(
499 | input.data_ptr<scalar_t>(),
500 | weight.data_ptr<scalar_t>(),
501 | output.data_ptr<scalar_t>(),
502 | num_elem,
503 | out_channels,
504 | in_channels,
505 | IW, IH,
506 | OW, OH,
507 | KW, KH,
508 | SW, SH,
509 | PW, PH,
510 | GW, GH
511 | );
512 | }));
513 | // AT_CUDA_CHECK(cudaGetLastError());
514 | C10_CUDA_CHECK(cudaGetLastError());
515 | return 1;
516 | }
517 |
518 | /*
519 | scalar_t *grad_out,
520 | scalar_t *weight,
521 | scalar_t *grad_in,
522 | const int num_elem,
523 | const int out_channels,
524 | const int in_channels,
525 | const int IW, const int IH,
526 | const int OW, const int OH,
527 | const int KW, const int KH,
528 | const int SW, const int SH,
529 | const int PW, const int PH,
530 | const int GW, const int GH
531 | */
532 |
533 | int adder_cuda_backward_grad_in(
534 | at::Tensor &grad_out,
535 | at::Tensor &input,
536 | at::Tensor &weight,
537 | at::Tensor &grad_in,
538 | int KW, int KH,
539 | int SW, int SH,
540 | int PW, int PH,
541 | int GW, int GH
542 | )
543 | {
544 | const int batch_size = grad_in.size(0);
545 | const int in_channels = grad_in.size(1);
546 | const int out_channels = grad_out.size(1);
547 | const int IW = grad_in.size(3);
548 | const int IH = grad_in.size(2);
549 | const int OW = grad_out.size(3);
550 | const int OH = grad_out.size(2);
551 | const int num_elem = batch_size * in_channels * IH * IW;
552 | const int num_blocks = get_blocks(num_elem);
553 |
554 | AT_DISPATCH_FLOATING_TYPES(grad_in.scalar_type(), "adder_cuda_backward_grad_in", ([&] {
555 | adder_backward_grad_in_kernel<scalar_t><<<num_blocks, NUM_THREADS>>>(
556 | grad_out.data_ptr<scalar_t>(),
557 | input.data_ptr<scalar_t>(),
558 | weight.data_ptr<scalar_t>(),
559 | grad_in.data_ptr<scalar_t>(),
560 | num_elem,
561 | out_channels,
562 | in_channels,
563 | IW, IH,
564 | OW, OH,
565 | KW, KH,
566 | SW, SH,
567 | PW, PH,
568 | GW, GH
569 | );
570 | }));
571 | // AT_CUDA_CHECK(cudaGetLastError());
572 | C10_CUDA_CHECK(cudaGetLastError());
573 | return 1;
574 | }
575 |
576 | int adder_cuda_backward_grad_weight(
577 | at::Tensor &grad_out,
578 | at::Tensor &input,
579 | at::Tensor &weight,
580 | at::Tensor &grad_weight,
581 | int KW, int KH,
582 | int SW, int SH,
583 | int PW, int PH,
584 | int GW, int GH
585 | )
586 | {
587 | const int batch_size = input.size(0);
588 | const int in_channels = input.size(1);
589 | const int out_channels = grad_out.size(1);
590 | const int IW = input.size(3);
591 | const int IH = input.size(2);
592 | const int OW = grad_out.size(3);
593 | const int OH = grad_out.size(2);
594 |
595 | int blocks = out_channels * in_channels * KH * KW;
596 |
597 | if (GW==2)
598 | {
599 | blocks = out_channels * (in_channels/2) * KH * KW;
600 | }
601 | if (GW==in_channels)
602 | {
603 | blocks = out_channels * 1 * KH * KW;
604 | }
605 |
606 | // Make sure we have enough threads to perform the reduction, and use this number
607 | // to create the shared memory size for the reduction
608 | dim3 grid(blocks);
609 | dim3 block(getGradParamsNumThreads(batch_size));
610 | // int smem = block.x * sizeof(accreal);
611 |
612 | AT_DISPATCH_FLOATING_TYPES(grad_weight.scalar_type(), "adder_cuda_backward_grad_weight", ([&] {
613 | adder_backward_grad_weight_kernel<scalar_t><<<grid, block>>>(
614 | grad_out.data_ptr<scalar_t>(),
615 | input.data_ptr<scalar_t>(),
616 | weight.data_ptr<scalar_t>(),
617 | grad_weight.data_ptr<scalar_t>(),
618 | batch_size,
619 | out_channels,
620 | in_channels,
621 | IW, IH,
622 | OW, OH,
623 | KW, KH,
624 | SW, SH,
625 | PW, PH,
626 | GW, GH);
627 | }));
628 | // AT_CUDA_CHECK(cudaGetLastError());
629 | C10_CUDA_CHECK(cudaGetLastError());
630 | return 1;
631 | }
632 |
633 | /*
634 | scalar_t *grad_out,
635 | scalar_t *input,
636 | scalar_t *grad_weight,
637 | const int batch_size,
638 | const int out_channels,
639 | const int in_channels,
640 | const int IW, const int IH,
641 | const int OW, const int OH,
642 | const int KW, const int KH,
643 | const int SW, const int SH,
644 | const int PW, const int PH,
645 | const int GW, const int GH
646 | */
--------------------------------------------------------------------------------
/torchshiftadd/layers/extension/adder_cuda_kernel_torch_1.4.cu:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | // #include
4 | // #include
5 | #include
6 | #include
7 | #include
8 | #include
9 |
10 | #include
11 | #define MAX_BLOCKS 256
12 | #define NUM_THREADS 256
13 | #define MAX(a, b) (((a) > (b)) ? (a) : (b))
14 | #define MIN(a, b) (((a) < (b)) ? (a) : (b)) // outer parens so MIN(...) + 1 in get_blocks groups as intended
15 | #define HARDTANH(x) (((x) < (-1.0)) ? (-1.0) : (((x) <= (1.0)) ? (x) : (1.0)))
16 | const int WARP_SIZE = 32;
17 | // Crude benchmarks suggest 256 is better than 512 and 1024
18 | // TODO: Autotune/use better heuristics, improve speed more.
19 | const int MAX_BLOCK_SIZE = 256;
20 |
21 | template <typename T>
22 | struct SharedMem {
23 | __device__ T *getPointer()
24 | {
25 | extern __device__ void error(void);
26 | error();
27 | return NULL;
28 | }
29 | };
30 |
31 |
32 | static int getGradParamsNumThreads(int batchSize)
33 | {
34 | //warp per item in a batch, up to a maximum
35 | return std::min(batchSize * WARP_SIZE, MAX_BLOCK_SIZE);
36 | }
37 |
38 | int get_blocks(int n)
39 | {
40 | // return MAX(1, MIN(MAX_BLOCKS, (n - NUM_THREADS + 1) / NUM_THREADS));
41 | return MIN(MAX_BLOCKS, (n - NUM_THREADS + 1) / NUM_THREADS) + 1;
42 | }
43 |
44 | template <typename scalar_t>
45 | __global__ void adder_forward_kernel(
46 | const scalar_t const *input,
47 | const scalar_t const *weight,
48 | // const scalar_t const *bias,
49 | scalar_t *output,
50 | const int num_elem,
51 | const int out_channels,
52 | const int in_channels,
53 | const int IW, const int IH,
54 | const int OW, const int OH,
55 | const int KW, const int KH,
56 | const int SW, const int SH,
57 | const int PW, const int PH,
58 | const int GW, const int GH)
59 | {
60 | // #TODO:
61 | if (GW==1)
62 | {
63 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
64 | {
65 | const int n = index / OW / OH / out_channels;
66 | const int m = index / OW / OH % out_channels;
67 | const int h = index / OW % OH;
68 | const int w = index % OW;
69 |
70 | const scalar_t *p_weight = weight + m * in_channels * KH * KW; //the start position of the kernel(corresponding to the output)
71 | // scalar_t value = bias[m];
72 | scalar_t value = 0;
73 | // #TODO:
74 | // #pragma unroll
75 | for (int cc = 0; cc < in_channels; cc++)
76 | {
77 | // #pragma unroll
78 | const int image_offset0 = (n * in_channels + cc) * IH * IW;
79 | for (int kh = 0; kh < KH; kh++)
80 | {
81 | // #pragma unroll
82 | for (int kw = 0; kw < KW; kw++)
83 | {
84 | const int ih = h * SH - PH + kh;
85 | const int iw = w * SW - PW + kw;
86 |
87 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
88 | if (boundary_condition)
89 | {
90 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
91 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight));
92 | }
93 | else // padded area
94 | {
95 | value -= abs(*p_weight);
96 | }
97 | p_weight++;
98 | }
99 | }
100 | }
101 | output[index] = value;
102 | }
103 | }
104 | if (GW==2)
105 | {
106 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
107 | // total size of output (batch_size * out_channels * W-out * H_out)
108 | {
109 | // TODO
110 | const int n = index / OW / OH / out_channels; // batch size of n
111 | const int m = index / OW / OH % out_channels; // relative output channel size
112 | const int h = index / OW % OH; // relative position of H
113 | const int w = index % OW; //relative position of W
114 |
115 | const scalar_t *p_weight = weight + m * in_channels/2 * KH * KW; //the start position of the kernel(corresponding to the output)
116 |
117 | // scalar_t value = bias[m];
118 | scalar_t value = 0;
119 | // #TODO:
120 | // #pragma unroll
121 | if (m < out_channels/2)
122 | {
123 | for (int cc = 0; cc < in_channels/2; cc++)
124 | {
125 | // #pragma unroll
126 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute)
127 | for (int kh = 0; kh < KH; kh++)
128 | {
129 | // #pragma unroll
130 | for (int kw = 0; kw < KW; kw++)
131 | {
132 | const int ih = h * SH - PH + kh; // *stride-padding of H
133 | const int iw = w * SW - PW + kw; // *stride-padding of W
134 |
135 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
136 | if (boundary_condition)
137 | {
138 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
139 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation
140 | }
141 | else // padded area
142 | {
143 | value -= abs(*p_weight);
144 | }
145 | p_weight++;
146 | }
147 | }
148 | }
149 | }
150 | else
151 | {
152 | for (int cc = in_channels/2; cc < in_channels; cc++)
153 | {
154 | // #pragma unroll
155 | const int image_offset0 = (n * in_channels + cc) * IH * IW; //channel offset (absolute)
156 | for (int kh = 0; kh < KH; kh++)
157 | {
158 | // #pragma unroll
159 | for (int kw = 0; kw < KW; kw++)
160 | {
161 | const int ih = h * SH - PH + kh; // *stride-padding of H
162 | const int iw = w * SW - PW + kw; // *stride-padding of W
163 |
164 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
165 | if (boundary_condition)
166 | {
167 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
168 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight)); //pure operation
169 | }
170 | else // padded area
171 | {
172 | value -= abs(*p_weight);
173 | }
174 | p_weight++;
175 | }
176 | }
177 | }
178 | }
179 | output[index] = value;
180 | }
181 | }
182 | if (GW==in_channels) // depthwise
183 | {
184 | for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
185 | {
186 | const int n = index / OW / OH / out_channels;
187 | const int m = index / OW / OH % out_channels;
188 | const int h = index / OW % OH;
189 | const int w = index % OW;
190 |
191 | const scalar_t *p_weight = weight + m * 1 * KH * KW;
192 | // scalar_t value = bias[m];
193 | scalar_t value = 0;
194 | // #TODO:
195 | // #pragma unroll
196 | // #pragma unroll
197 | const int image_offset0 = (n * in_channels + m) * IH * IW;
198 | for (int kh = 0; kh < KH; kh++)
199 | {
200 | // #pragma unroll
201 | for (int kw = 0; kw < KW; kw++)
202 | {
203 | const int ih = h * SH - PH + kh;
204 | const int iw = w * SW - PW + kw;
205 |
206 | bool boundary_condition = (ih >= 0) && (ih < IH) && (iw >= 0) && (iw < IW);
207 | if (boundary_condition)
208 | {
209 | // value += input[image_offset0 + ih * IW + iw] * (*p_weight);
210 | value -= abs(input[image_offset0 + ih * IW + iw] - (*p_weight));
211 | }
212 | else // padded area
213 | {
214 | value -= abs(*p_weight);
215 | }
216 | p_weight++;
217 | }
218 | }
219 | output[index] = value;
220 | }
221 | }
222 |
223 | }
224 |
225 | template <typename scalar_t>
226 | __global__ void adder_backward_grad_in_kernel(
227 | scalar_t *grad_out,
228 | scalar_t *input,
229 | scalar_t *weight,
230 | scalar_t *grad_in,
231 | const int num_elem,
232 | const int out_channels,
233 | const int in_channels,
234 | const int IW, const int IH,
235 | const int OW, const int OH,
236 | const int KW, const int KH,
237 | const int SW, const int SH,
238 | const int PW, const int PH,
239 | const int GW, const int GH)
240 | { if (GW==1)
241 | {
242 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
243 | {
244 | const int n = index / IW / IH / in_channels;
245 | const int c = index / IW / IH % in_channels;
246 | const int h = index / IW % IH;
247 | const int w = index % IW;
248 |
249 | scalar_t value = 0;
250 | for (int mm = 0; mm < out_channels; mm++)
251 | {
252 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW;
253 | scalar_t *p_weight = weight + (mm * in_channels + c) * KH * KW;
254 | for (int kh = 0; kh < KH; kh++)
255 | {
256 | for (int kw = 0; kw < KW; kw++)
257 | {
258 | int oh = h + PH - kh;
259 | int ow = w + PW - kw;
260 |
261 | if ((oh % SH == 0) && (ow % SW == 0))
262 | {
263 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
264 | if (boundary_condition)
265 | {
266 | oh = oh / SH;
267 | ow = ow / SW;
268 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
269 | scalar_t ht = HARDTANH(*p_weight - input[index]);
270 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
271 | }
272 | }
273 | p_weight++;
274 | }
275 | }
276 | }
277 | grad_in[index] = value;
278 | }
279 | }
280 | if (GW==2)
281 | {
282 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
283 | {
284 | const int n = index / IW / IH / in_channels;
285 | const int c = index / IW / IH % in_channels;
286 | const int h = index / IW % IH;
287 | const int w = index % IW;
288 |
289 | scalar_t value = 0;
290 | if (c < in_channels/2)
291 | {
292 | for (int mm = 0; mm < out_channels/2; mm++)
293 | {
294 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW;
295 | scalar_t *p_weight = weight + (mm * in_channels/2 + c) * KH * KW;
296 | for (int kh = 0; kh < KH; kh++)
297 | {
298 | for (int kw = 0; kw < KW; kw++)
299 | {
300 | int oh = h + PH - kh;
301 | int ow = w + PW - kw;
302 |
303 | if ((oh % SH == 0) && (ow % SW == 0))
304 | {
305 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
306 | if (boundary_condition)
307 | {
308 | oh = oh / SH;
309 | ow = ow / SW;
310 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
311 | scalar_t ht = HARDTANH(*p_weight - input[index]);
312 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
313 | }
314 | }
315 | p_weight++;
316 | }
317 | }
318 | }
319 | }
320 | else
321 | {
322 | for (int mm = out_channels/2; mm < out_channels; mm++)
323 | {
324 | const int grad_out_offset0 = (n * out_channels + mm) * OH * OW;
325 | scalar_t *p_weight = weight + (mm * in_channels/2 + c - out_channels/2) * KH * KW;
326 | for (int kh = 0; kh < KH; kh++)
327 | {
328 | for (int kw = 0; kw < KW; kw++)
329 | {
330 | int oh = h + PH - kh;
331 | int ow = w + PW - kw;
332 |
333 | if ((oh % SH == 0) && (ow % SW == 0))
334 | {
335 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
336 | if (boundary_condition)
337 | {
338 | oh = oh / SH;
339 | ow = ow / SW;
340 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
341 | scalar_t ht = HARDTANH(*p_weight - input[index]);
342 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
343 | }
344 | }
345 | p_weight++;
346 | }
347 | }
348 | }
349 | }
350 | grad_in[index] = value;
351 | }
352 | }
353 | if (GW==in_channels)
354 | {
355 | for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_elem; index += gridDim.x * blockDim.x)
356 | {
357 | const int n = index / IW / IH / in_channels;
358 | const int c = index / IW / IH % in_channels;
359 | const int h = index / IW % IH;
360 | const int w = index % IW;
361 |
362 | scalar_t value = 0;
363 |
364 | const int grad_out_offset0 = (n * out_channels + c) * OH * OW;
365 | scalar_t *p_weight = weight + c * KH * KW;
366 | for (int kh = 0; kh < KH; kh++)
367 | {
368 | for (int kw = 0; kw < KW; kw++)
369 | {
370 | int oh = h + PH - kh;
371 | int ow = w + PW - kw;
372 |
373 | if ((oh % SH == 0) && (ow % SW == 0))
374 | {
375 | const bool boundary_condition = (oh >= 0) && (oh < OH) && (ow >= 0) && (ow < OW);
376 | if (boundary_condition)
377 | {
378 | oh = oh / SH;
379 | ow = ow / SW;
380 | // value += grad_out[grad_out_offset0 + oh * OW + ow] * (*p_weight);
381 | scalar_t ht = HARDTANH(*p_weight - input[index]);
382 | value += grad_out[grad_out_offset0 + oh * OW + ow] * ht;
383 | }
384 | }
385 | p_weight++;
386 | }
387 | }
388 | grad_in[index] = value;
389 | }
390 | }
391 | }
392 |
393 | template <typename scalar_t>
394 | __global__ void adder_backward_grad_weight_kernel(
395 | scalar_t *grad_out,
396 | scalar_t *input,
397 | scalar_t *weight,
398 | scalar_t *grad_weight,
399 | const int batch_size,
400 | const int out_channels,
401 | const int in_channels,
402 | const int IW, const int IH,
403 | const int OW, const int OH,
404 | const int KW, const int KH,
405 | const int SW, const int SH,
406 | const int PW, const int PH,
407 | const int GW, const int GH)
408 | {
409 | SharedMem<scalar_t> smem;
410 | int bidx = blockIdx.x;
411 | int kW = bidx % KW;
412 | int kH = bidx / KW % KH;
413 | int ch = bidx / KW / KH % in_channels;
414 | int mh = bidx / KW / KH / in_channels;
415 |
416 | if (GW==2)
417 | {
418 | ch = bidx / KW / KH % (in_channels/2);
419 | mh = bidx / KW / KH / (in_channels/2);
420 | if (mh >= out_channels/2)
421 | {
422 | ch = ch + in_channels/2;
423 | }
424 | }
425 | if (GW==in_channels)
426 | {
427 | ch = bidx / KW / KH;
428 | mh = bidx / KW / KH;
429 | }
430 |
431 | scalar_t grad = 0;
432 | const int laneId = threadIdx.x % WARP_SIZE;
433 | const int batch = threadIdx.x / WARP_SIZE;
434 | const int nwarps = blockDim.x / WARP_SIZE;
435 | const int imageElements = OW * OH;
436 | for (int batchIdx = batch; batchIdx < batch_size; batchIdx += nwarps)
437 | {
438 | // Warp-stride loop over elements in a batch item
439 | for (int idx = laneId; idx < imageElements; idx += WARP_SIZE)
440 | {
441 | // Need to calculate the following: batch position, and offset into the gradOutput
442 | // in height, and width. We can intuit the corresponding position in the input from
443 | // the other parameters we have
444 | int go_w_offset = idx % OW;
445 | int go_h_offset = (idx / OW);
446 |
447 | int i_w_offset = go_w_offset * SW + kW - PW;
448 | int i_h_offset = go_h_offset * SH + kH - PH;
449 |
450 | int outputOffset = ((batchIdx * out_channels + mh) * OH) * OW + idx;
451 | if (i_w_offset >= 0 && i_h_offset >= 0 && i_w_offset < IW && i_h_offset < IH)
452 | {
453 | int inputOffset = ((batchIdx * in_channels + ch) * IH + i_h_offset) * IW + i_w_offset;
454 | // int outputOffset = ((batchIdx * out_channels + mh) * OH) * OW + idx;
455 | // grad += input[inputOffset] * grad_out[outputOffset];
456 | grad += (input[inputOffset] - weight[bidx]) * grad_out[outputOffset];
457 | }
458 | else // padded area
459 | {
460 | grad += - weight[bidx] * grad_out[outputOffset];
461 | }
462 | }
463 | }
464 | __syncthreads();
465 | scalar_t *buf = smem.getPointer();
466 | scalar_t tval = reduceBlock<scalar_t, ReduceAdd<scalar_t>>(
467 | buf, blockDim.x, grad, ReduceAdd<scalar_t>(), 0);
468 |
469 | // After reduction, first thread in the block has the gradient, so its responsible
470 | // for writing it to gradWeight
471 | if (threadIdx.x == 0)
472 | {
473 | if (GW==1)
474 | {
475 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * in_channels * mh);
476 | grad_weight[weightOffset] = tval;
477 | }
478 | if (GW==2)
479 | {
480 | if (mh < out_channels/2)
481 | {
482 | int weightOffset = kW + (KW * kH) + (KW * KH * ch) + (KW * KH * in_channels/2 * mh);
483 | grad_weight[weightOffset] = tval;
484 | }
485 | else
486 | {
487 | int weightOffset = kW + (KW * kH) + (KW * KH * (ch - in_channels/2)) + (KW * KH * in_channels/2 * mh);
488 | grad_weight[weightOffset] = tval;
489 | }
490 |
491 | }
492 | if (GW==in_channels)
493 | {
494 | int weightOffset = kW + (KW * kH) + (KW * KH * 0) + (KW * KH * 1 * mh);
495 | grad_weight[weightOffset] = tval;
496 | }
497 | }
498 | }
499 |
500 | ////////////////////////////////////////////////////////////////////////
501 | ////////////////////////////END OF KERNEL///////////////////////////////
502 | ////////////////////////////////////////////////////////////////////////
503 |
504 | int adder_cuda_forward(
505 | const at::Tensor &input,
506 | const at::Tensor &weight,
507 | // const at::Tensor &bias,
508 | at::Tensor &output,
509 | int KW, int KH,
510 | int SW, int SH,
511 | int PW, int PH,
512 | int GW, int GH
513 | )
514 | {
515 | const int batch_size = output.size(0);
516 | const int in_channels = input.size(1);
517 | const int out_channels = output.size(1);
518 | const int IW = input.size(3);
519 | const int IH = input.size(2);
520 | const int OW = output.size(3);
521 | const int OH = output.size(2);
522 | const int num_elem = batch_size * out_channels * OH * OW;
523 | const int num_blocks = get_blocks(num_elem);
524 |
525 | AT_DISPATCH_FLOATING_TYPES(output.type(), "adder_cuda_forward", ([&] {
526 | adder_forward_kernel<scalar_t><<<num_blocks, NUM_THREADS>>>(
527 | input.data<scalar_t>(),
528 | weight.data<scalar_t>(),
529 | // bias.data<scalar_t>(),
530 | output.data<scalar_t>(),
531 | num_elem,
532 | out_channels,
533 | in_channels,
534 | IW, IH,
535 | OW, OH,
536 | KW, KH,
537 | SW, SH,
538 | PW, PH,
539 | GW, GH
540 | );
541 | }));
542 | AT_CUDA_CHECK(cudaGetLastError());
543 | return 1;
544 | }
545 |
546 | /*
547 | scalar_t *grad_out,
548 | scalar_t *weight,
549 | scalar_t *grad_in,
550 | const int num_elem,
551 | const int out_channels,
552 | const int in_channels,
553 | const int IW, const int IH,
554 | const int OW, const int OH,
555 | const int KW, const int KH,
556 | const int SW, const int SH,
557 | const int PW, const int PH,
558 | const int GW, const int GH
559 | */
560 |
561 | int adder_cuda_backward_grad_in(
562 | at::Tensor &grad_out,
563 | at::Tensor &input,
564 | at::Tensor &weight,
565 | at::Tensor &grad_in,
566 | int KW, int KH,
567 | int SW, int SH,
568 | int PW, int PH,
569 | int GW, int GH
570 | )
571 | {
572 | const int batch_size = grad_in.size(0);
573 | const int in_channels = grad_in.size(1);
574 | const int out_channels = grad_out.size(1);
575 | const int IW = grad_in.size(3);
576 | const int IH = grad_in.size(2);
577 | const int OW = grad_out.size(3);
578 | const int OH = grad_out.size(2);
579 | const int num_elem = batch_size * in_channels * IH * IW;
580 | const int num_blocks = get_blocks(num_elem);
581 |
582 | AT_DISPATCH_FLOATING_TYPES(grad_in.type(), "adder_cuda_backward_grad_in", ([&] {
583 | adder_backward_grad_in_kernel<scalar_t><<<num_blocks, NUM_THREADS>>>(
584 | grad_out.data<scalar_t>(),
585 | input.data<scalar_t>(),
586 | weight.data<scalar_t>(),
587 | grad_in.data<scalar_t>(),
588 | num_elem,
589 | out_channels,
590 | in_channels,
591 | IW, IH,
592 | OW, OH,
593 | KW, KH,
594 | SW, SH,
595 | PW, PH,
596 | GW, GH
597 | );
598 | }));
599 | AT_CUDA_CHECK(cudaGetLastError());
600 | return 1;
601 | }
602 |
603 | int adder_cuda_backward_grad_weight(
604 | at::Tensor &grad_out,
605 | at::Tensor &input,
606 | at::Tensor &weight,
607 | at::Tensor &grad_weight,
608 | int KW, int KH,
609 | int SW, int SH,
610 | int PW, int PH,
611 | int GW, int GH
612 | )
613 | {
614 | const int batch_size = input.size(0);
615 | const int in_channels = input.size(1);
616 | const int out_channels = grad_out.size(1);
617 | const int IW = input.size(3);
618 | const int IH = input.size(2);
619 | const int OW = grad_out.size(3);
620 | const int OH = grad_out.size(2);
621 |
622 | int blocks = out_channels * in_channels * KH * KW;
623 |
624 | if (GW==2)
625 | {
626 | blocks = out_channels * (in_channels/2) * KH * KW;
627 | }
628 | if (GW==in_channels)
629 | {
630 | blocks = out_channels * 1 * KH * KW;
631 | }
632 |
633 | // Make sure we have enough threads to perform the reduction, and use this number
634 | // to create the shared memory size for the reduction
635 | dim3 grid(blocks);
636 | dim3 block(getGradParamsNumThreads(batch_size));
637 | // int smem = block.x * sizeof(accreal);
638 |
639 | AT_DISPATCH_FLOATING_TYPES(grad_weight.type(), "adder_cuda_backward_grad_weight", ([&] {
640 | adder_backward_grad_weight_kernel<scalar_t><<<grid, block, block.x * sizeof(scalar_t)>>>( // dynamic smem: one partial sum per thread
641 | grad_out.data<scalar_t>(),
642 | input.data<scalar_t>(),
643 | weight.data<scalar_t>(),
644 | grad_weight.data<scalar_t>(),
645 | batch_size,
646 | out_channels,
647 | in_channels,
648 | IW, IH,
649 | OW, OH,
650 | KW, KH,
651 | SW, SH,
652 | PW, PH,
653 | GW, GH);
654 | }));
655 | AT_CUDA_CHECK(cudaGetLastError());
656 | return 1;
657 | }
658 |
659 | /*
660 | scalar_t *grad_out,
661 | scalar_t *input,
662 | scalar_t *grad_weight,
663 | const int batch_size,
664 | const int out_channels,
665 | const int in_channels,
666 | const int IW, const int IH,
667 | const int OW, const int OH,
668 | const int KW, const int KH,
669 | const int SW, const int SH,
670 | const int PW, const int PH,
671 | const int GW, const int GH
672 | */
--------------------------------------------------------------------------------
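Note: as a plain-PyTorch cross-check of what the forward kernels above compute, the adder "convolution" replaces the dot product with a negated L1 distance between each input patch and the filter (zero padding contributes -|w| per padded tap, matching the padded-area branch). The sketch below covers the ungrouped case (GW == 1); the helper name is illustrative.

import torch
import torch.nn.functional as F

def adder2d_reference(x, weight, stride=1, padding=0):
    # x: (N, C, H, W), weight: (M, C, KH, KW); output[n, m, h, w] = -sum |patch - weight[m]|
    N, C, H, W = x.shape
    M, _, KH, KW = weight.shape
    cols = F.unfold(x, (KH, KW), stride=stride, padding=padding)   # (N, C*KH*KW, OH*OW)
    OH = (H + 2 * padding - KH) // stride + 1
    OW = (W + 2 * padding - KW) // stride + 1
    w = weight.view(1, M, -1, 1)                                   # (1, M, C*KH*KW, 1)
    out = -(cols.unsqueeze(1) - w).abs().sum(dim=2)                # (N, M, OH*OW)
    return out.view(N, M, OH, OW)

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
print(adder2d_reference(x, w, stride=1, padding=1).shape)          # torch.Size([2, 4, 8, 8])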
/torchshiftadd/layers/shift.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import torch.nn.functional as F
6 | from torch.autograd import Function
7 | from torch.nn.modules.utils import _pair
8 | import torchshiftadd.utils.ste as ste
9 | from torchshiftadd.utils.quantize import quantize_grad
10 |
11 | log2 = math.log(2)
12 |
13 | ###### FC
14 |
15 | class LinearShiftFunction(Function):
16 |
17 | @staticmethod
18 | def forward(ctx, input, shift, sign, bias=None, conc_weight=None, use_kernel=False, use_cuda=True, rounding='deterministic', shift_range=(-14, 0)):
19 | fraction_bits = 16
20 | integer_bit = 16
21 |
22 | sign = sign.clamp(-1,1)
23 | shift = shift.clamp(*shift_range)
24 | input.data = ste.round_to_fixed(input.data, fraction_bits, integer_bit)
25 | if bias is not None:
26 | bias.data = ste.round_to_fixed(bias.data, fraction_bits, integer_bit)
27 |
28 | v = 2**shift.round() * sign.round().sign()
29 | out = input.mm(v.t())
30 | if bias is not None:
31 | out += bias.unsqueeze(0).expand_as(out)
32 |
33 | ctx.save_for_backward(input, shift, sign, bias, v)
34 |
35 | return out
36 |
37 | @staticmethod
38 | def backward(ctx, grad_output):
39 |
40 | input, shift, sign, bias, v = ctx.saved_tensors
41 | grad_input = grad_shift = grad_sign = grad_bias = None
42 |
43 | if ctx.needs_input_grad[0]:
44 | grad_input = grad_output.mm(v)
45 | if ctx.needs_input_grad[2]:
46 | grad_sign = grad_output.t().mm(input)
47 | if ctx.needs_input_grad[1]:
48 | if grad_sign is None:
49 | grad_shift = grad_output.t().mm(input) * v * log2
50 | else:
51 | grad_shift = grad_sign * v * log2
52 | if bias is not None and ctx.needs_input_grad[3]:
53 | grad_bias = grad_output.sum(0).squeeze(0)
54 |
55 | return grad_input, grad_shift, grad_sign, grad_bias, None, None, None
56 |
57 |
58 | class LinearShift(nn.Module):
59 | def __init__(self, in_features, out_features, bias=True, check_grad=False, freeze_sign=False, use_kernel=False, use_cuda=True, rounding='deterministic', weight_bits=5, threshold=None):
60 | super(LinearShift, self).__init__()
61 | self.in_features = in_features
62 | self.out_features = out_features
63 | self.use_kernel = use_kernel
64 | self.check_grad = check_grad
65 | self.use_cuda = use_cuda
66 | self.conc_weight = None
67 | self.rounding = rounding
68 | self.shift_range = (-1 * (2**(weight_bits - 1) - 2), 0) # we use ternary weights to represent sign
69 | self.threshold = threshold
70 | print(self.shift_range)
71 |
72 | if check_grad:
73 | tensor_constructor = torch.DoubleTensor # double precision required to check grad
74 | else:
75 | tensor_constructor = torch.Tensor # In PyTorch torch.Tensor is alias torch.FloatTensor
76 |
77 | self.shift = nn.Parameter(tensor_constructor(out_features, in_features))
78 | self.sign = nn.Parameter(tensor_constructor(out_features, in_features), requires_grad = (freeze_sign == False))
79 |
80 | if bias:
81 | self.bias = nn.Parameter(tensor_constructor(out_features))
82 | else:
83 | self.register_parameter('bias', None)
84 |
85 | self.reset_parameters()
86 |
87 | def reset_parameters(self):
88 | self.shift.data.uniform_(*self.shift_range)
89 | self.sign.data.uniform_(-1, 1)
90 |
91 | if self.bias is not None:
92 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.shift)
93 | bound = 1 / math.sqrt(fan_in)
94 | init.uniform_(self.bias, -bound, bound)
95 |
96 | def forward(self, input):
97 | self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
98 | shift_rounded = ste.round(self.shift, rounding=self.rounding)
99 | if self.threshold is None:
100 | sign_rounded_signed = ste.sign(ste.round(self.sign, rounding=self.rounding))
101 | else:
102 | sign_rounded_signed = ste.sign(ste.myround(self.sign, self.threshold))
103 | weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)
104 |
105 | if self.use_kernel:
106 | return LinearShiftFunction.apply(input, self.shift, self.sign, self.bias, self.conc_weight, self.use_kernel, self.use_cuda, self.rounding, self.shift_range)
107 | else:
108 | return torch.nn.functional.linear(input, weight_ps, self.bias)
109 |
110 | def extra_repr(self):
111 | return 'in_features={}, out_features={}, bias={}'.format(
112 | self.in_features, self.out_features, self.bias is not None
113 | )
114 |
115 | ##### Conv
116 |
117 | class _ConvNdShift(nn.Module):
118 |
119 | __constants__ = ['stride', 'padding', 'dilation', 'groups', 'bias', 'padding_mode']
120 |
121 | def __init__(self, in_channels, out_channels, kernel_size, stride,
122 | padding, dilation, transposed, output_padding,
123 | groups, bias, padding_mode,
124 | check_grad=False, freeze_sign=False,
125 | rounding='deterministic', weight_bits=5):
126 | super(_ConvNdShift, self).__init__()
127 | if in_channels % groups != 0:
128 | raise ValueError('in_channels must be divisible by groups')
129 | if out_channels % groups != 0:
130 | raise ValueError('out_channels must be divisible by groups')
131 | self.in_channels = in_channels
132 | self.out_channels = out_channels
133 | self.kernel_size = kernel_size
134 | self.stride = stride
135 | self.padding = padding
136 | self.dilation = dilation
137 | self.transposed = transposed
138 | self.output_padding = output_padding
139 | self.groups = groups
140 | self.padding_mode = padding_mode
141 | self.rounding=rounding
142 | self.shift_range = (-1 * (2**(weight_bits - 1) - 2), 0) # we use ternary weights to represent sign
143 | # for ps
144 | # self.shift_range = (-1 * weight_bits, 0)
145 |
146 | if check_grad:
147 | tensor_constructor = torch.DoubleTensor # double precision required to check grad
148 | else:
149 | tensor_constructor = torch.Tensor # In PyTorch torch.Tensor is alias torch.FloatTensor
150 |
151 | if transposed:
152 | self.shift = nn.Parameter(tensor_constructor(
153 | in_channels, out_channels // groups, *kernel_size))
154 | self.sign = nn.Parameter(tensor_constructor(
155 | in_channels, out_channels // groups, *kernel_size),
156 | requires_grad = (freeze_sign == False))
157 | else:
158 | self.shift = nn.Parameter(tensor_constructor(
159 | out_channels, in_channels // groups, *kernel_size))
160 | self.sign = nn.Parameter(tensor_constructor(
161 | out_channels, in_channels // groups, *kernel_size),
162 | requires_grad = (freeze_sign == False))
163 | if bias:
164 | self.bias = nn.Parameter(tensor_constructor(out_channels))
165 | else:
166 | self.register_parameter('bias', None)
167 | self.reset_parameters(weight_bits)
168 |
169 | def reset_parameters(self, weight_bits):
170 | self.shift.data.uniform_(*self.shift_range) # (-0.1, 0.1)
171 | self.sign.data.uniform_(-1, 1) # (-0.1, 0.1)
172 |
173 | if self.bias is not None:
174 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.shift)
175 | bound = 1 / math.sqrt(fan_in)
176 | init.uniform_(self.bias, -bound, bound)
177 |
178 | def extra_repr(self):
179 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
180 | ', stride={stride}')
181 | if self.padding != (0,) * len(self.padding):
182 | s += ', padding={padding}'
183 | if self.dilation != (1,) * len(self.dilation):
184 | s += ', dilation={dilation}'
185 | if self.output_padding != (0,) * len(self.output_padding):
186 | s += ', output_padding={output_padding}'
187 | if self.groups != 1:
188 | s += ', groups={groups}'
189 | if self.bias is None:
190 | s += ', bias=False'
191 | return s.format(**self.__dict__)
192 |
193 |
194 | class Conv2dShift(_ConvNdShift):
195 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
196 | padding=0, dilation=1, groups=1,
197 | bias=True, padding_mode='zeros',
198 | check_grad=False, freeze_sign=False, use_kernel=False, use_cuda=True, rounding='deterministic', weight_bits=5, threshold=0.3, input_bits=16):
199 | kernel_size = _pair(kernel_size)
200 | stride = _pair(stride)
201 | padding = _pair(padding)
202 | dilation = _pair(dilation)
203 | self.use_kernel = use_kernel
204 | self.use_cuda = use_cuda
205 | self.conc_weight = None
206 | self.threshold = threshold
207 | self.input_bits = input_bits
208 | super(Conv2dShift, self).__init__(
209 | in_channels, out_channels, kernel_size, stride, padding, dilation,
210 | False, _pair(0), groups, bias, padding_mode,
211 | check_grad, freeze_sign, rounding, weight_bits)
212 |
213 | #@weak_script_method
214 | def forward(self, input):
215 | self.shift.data = ste.clamp(self.shift.data, *self.shift_range)
216 | shift_rounded = ste.round(self.shift, self.rounding)
217 |
218 | if self.threshold is None:
219 | sign_rounded_signed = ste.sign(ste.round(self.sign, self.rounding))
220 | else:
221 | sign_rounded_signed = ste.sign(ste.myround(self.sign, self.threshold))
222 | weight_ps = ste.unsym_grad_mul(2**shift_rounded, sign_rounded_signed)
223 |
224 | input_fixed_point = ste.round_fixed_point(input, quant_bits=self.input_bits)
225 |
226 | if self.bias is not None:
227 | bias_fixed_point = ste.round_fixed_point(self.bias, quant_bits=self.input_bits)
228 | else:
229 | bias_fixed_point = None
230 |
231 | if self.padding_mode == 'circular':
232 | expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
233 | (self.padding[0] + 1) // 2, self.padding[0] // 2)
234 |
235 | input_padded = F.pad(input_fixed_point, expanded_padding, mode='circular')
236 | padding = _pair(0)
237 | else:
238 | input_padded = input_fixed_point
239 | padding = self.padding
240 |
241 |
242 | output = torch.nn.functional.conv2d(input_padded, weight_ps, bias_fixed_point,
243 | self.stride, padding, self.dilation, self.groups)
244 |
245 | # quantize gradients during backpropagation
246 | if self.input_bits > 0:
247 | output = quantize_grad(output, num_bits=self.input_bits, flatten_dims=(1, -1))
248 |
249 | return output
--------------------------------------------------------------------------------
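Note: a small CPU usage sketch for the shift layers above, assuming the package imports cleanly in your environment (the sibling adder layer may try to build its CUDA extension on import) and that the ste helpers used in forward run on CPU. Weights are parameterized as sign * 2**shift, so an approximate effective dense weight can be reconstructed for inspection (ignoring the straight-through-estimator details).

import torch
from torchshiftadd.layers import shift

layer = shift.LinearShift(in_features=8, out_features=4, bias=True)
x = torch.randn(2, 8)
y = layer(x)                # forward uses rounded power-of-two weights internally
print(y.shape)              # torch.Size([2, 4])

# Approximate reconstruction of the effective weight used in forward.
w_eff = torch.sign(layer.sign.round()) * 2.0 ** layer.shift.round()
print(w_eff.shape)          # torch.Size([4, 8])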
/torchshiftadd/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet20 import resnet20
2 | from .resnet20_adder import resnet20_adder
3 | from .resnet20_shift import convert_to_shift
4 | from .resnet20_shiftadd import resnet20_shiftadd
--------------------------------------------------------------------------------
/torchshiftadd/models/resnet20.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['resnet20']
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | return nn.Conv2d(
10 | in_planes,
11 | out_planes,
12 | kernel_size=3,
13 | stride=stride,
14 | padding=1,
15 | bias=False,
16 | )
17 |
18 |
19 | class BasicBlock(nn.Module):
20 |
21 | expansion=1
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(inplanes, planes, stride=stride)
26 | self.bn1 = nn.BatchNorm2d(planes)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.bn2 = nn.BatchNorm2d(planes)
30 | self.downsample = downsample
31 | self.stride = stride
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out = self.conv1(x)
37 | out = self.bn1(out)
38 | out = self.relu(out)
39 |
40 | out = self.conv2(out)
41 | out = self.bn2(out)
42 |
43 | if self.downsample is not None:
44 | residual = self.downsample(x)
45 |
46 | out += residual
47 | out = self.relu(out)
48 |
49 | return out
50 |
51 | class ResNet(nn.Module):
52 |
53 | def __init__(self, block, layers, num_classes=10):
54 | super(ResNet, self).__init__()
55 | self.inplanes = 16
56 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
57 | self.bn1 = nn.BatchNorm2d(16)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.layer1 = self._make_layer(block, 16, layers[0])
60 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
61 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
62 | self.avgpool = nn.AvgPool2d(8, stride=1)
63 | self.fc = nn.Conv2d(64 * block.expansion, num_classes, 1, bias=False)
64 | self.bn2 = nn.BatchNorm2d(num_classes)
65 |
66 |
67 | # init
68 | for m in self.modules():
69 | if isinstance(m, nn.BatchNorm2d):
70 | m.weight.data.fill_(1)
71 | m.bias.data.zero_()
72 |
73 | def _make_layer(self, block, planes, blocks, stride=1):
74 | downsample = None
75 | if stride != 1 or self.inplanes != planes * block.expansion:
76 | downsample = nn.Sequential(
77 | nn.Conv2d(
78 | self.inplanes,
79 | planes * block.expansion,
80 | kernel_size=1,
81 | stride=stride,
82 | bias=False,
83 | ),
84 | nn.BatchNorm2d(planes * block.expansion)
85 | )
86 |
87 | layers = []
88 | layers.append(
89 | block(
90 | inplanes=self.inplanes,
91 | planes=planes,
92 | stride=stride,
93 | downsample=downsample,
94 | )
95 | )
96 | self.inplanes = planes * block.expansion
97 | for _ in range(1, blocks):
98 | layers.append(
99 | block(
100 | inplanes=self.inplanes,
101 | planes=planes,
102 | )
103 | )
104 |
105 | return nn.Sequential(*layers)
106 |
107 | def forward(self, x):
108 | x = self.conv1(x)
109 | x = self.bn1(x)
110 | x = self.relu(x)
111 |
112 | x = self.layer1(x)
113 | x = self.layer2(x)
114 | x = self.layer3(x)
115 |
116 | x = self.avgpool(x)
117 | x = self.fc(x)
118 | x = self.bn2(x)
119 |
120 | return x.view(x.size(0), -1)
121 |
122 | def resnet20(num_classes=10, **kwargs):
123 | return ResNet(
124 | BasicBlock,
125 | [3, 3, 3],
126 | num_classes=num_classes,
127 | )
--------------------------------------------------------------------------------
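Note: a quick shape sanity check for the baseline model above on CIFAR-10-sized inputs, assuming the top-level package imports cleanly (importing torchshiftadd.models also pulls in the adder/shift variants).

import torch
from torchshiftadd.models import resnet20

model = resnet20(num_classes=10)
logits = model(torch.randn(4, 3, 32, 32))
print(logits.shape)         # torch.Size([4, 10])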
/torchshiftadd/models/resnet20_adder.py:
--------------------------------------------------------------------------------
1 | from torchshiftadd.layers import adder
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['resnet20_adder']
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | return adder.Adder2D(
10 | in_planes,
11 | out_planes,
12 | kernel_size=3,
13 | stride=stride,
14 | padding=1,
15 | bias=False,
16 | )
17 |
18 |
19 | class BasicBlock(nn.Module):
20 |
21 | expansion=1
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(inplanes, planes, stride=stride)
26 | self.bn1 = nn.BatchNorm2d(planes)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.bn2 = nn.BatchNorm2d(planes)
30 | self.downsample = downsample
31 | self.stride = stride
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out = self.conv1(x)
37 | out = self.bn1(out)
38 | out = self.relu(out)
39 |
40 | out = self.conv2(out)
41 | out = self.bn2(out)
42 |
43 | if self.downsample is not None:
44 | residual = self.downsample(x)
45 |
46 | out += residual
47 | out = self.relu(out)
48 |
49 | return out
50 |
51 | class ResNet(nn.Module):
52 |
53 | def __init__(self, block, layers, num_classes=10):
54 | super(ResNet, self).__init__()
55 | self.inplanes = 16
56 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
57 | self.bn1 = nn.BatchNorm2d(16)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.layer1 = self._make_layer(block, 16, layers[0])
60 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
61 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
62 | self.avgpool = nn.AvgPool2d(8, stride=1)
63 | # use conv as fc layer (addernet)
64 | self.fc = nn.Conv2d(64 * block.expansion, num_classes, 1, bias=False)
65 | self.bn2 = nn.BatchNorm2d(num_classes)
66 |
67 |
68 | # init (for adder)
69 | for m in self.modules():
70 | if isinstance(m, nn.BatchNorm2d):
71 | m.weight.data.fill_(1)
72 | m.bias.data.zero_()
73 |
74 | def _make_layer(self, block, planes, blocks, stride=1):
75 | downsample = None
76 | if stride != 1 or self.inplanes != planes * block.expansion:
77 | downsample = nn.Sequential(
78 | adder.Adder2D(
79 | self.inplanes,
80 | planes * block.expansion,
81 | kernel_size=1,
82 | stride=stride,
83 | bias=False,
84 | ),
85 | nn.BatchNorm2d(planes * block.expansion)
86 | )
87 |
88 | layers = []
89 | layers.append(
90 | block(
91 | inplanes=self.inplanes,
92 | planes=planes,
93 | stride=stride,
94 | downsample=downsample,
95 | )
96 | )
97 | self.inplanes = planes * block.expansion
98 | for _ in range(1, blocks):
99 | layers.append(
100 | block(
101 | inplanes=self.inplanes,
102 | planes=planes,
103 | )
104 | )
105 |
106 | return nn.Sequential(*layers)
107 |
108 | def forward(self, x):
109 | x = self.conv1(x)
110 | x = self.bn1(x)
111 | x = self.relu(x)
112 |
113 | x = self.layer1(x)
114 | x = self.layer2(x)
115 | x = self.layer3(x)
116 |
117 | x = self.avgpool(x)
118 | x = self.fc(x)
119 | x = self.bn2(x)
120 |
121 | return x.view(x.size(0), -1)
122 |
123 | def resnet20_adder(num_classes=10, **kwargs):
124 | return ResNet(
125 | BasicBlock,
126 | [3, 3, 3],
127 | num_classes=num_classes,
128 | )
--------------------------------------------------------------------------------
/torchshiftadd/models/resnet20_shift.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from torchshiftadd.layers import shift
6 |
7 | def convert_to_shift(model):
8 | conversion_count = 0
9 |
10 | for name, module in reversed(model._modules.items()):
11 |
12 | if len(list(module.children())) > 0:
13 | model._modules[name], num_converted = convert_to_shift(model=module)
14 | conversion_count += num_converted
15 |
16 | if type(module) == nn.Conv2d:
17 | conv2d = module
18 | shift_conv2d = shift.Conv2dShift(
19 | module.in_channels,
20 | module.out_channels,
21 | module.kernel_size,
22 | module.stride,
23 | module.padding,
24 | module.dilation,
25 | module.groups,
26 | module.bias is not None,
27 | module.padding_mode
28 | )
29 | shift_conv2d.shift.data, shift_conv2d.sign.data = get_shift_and_sign(conv2d.weight)
30 | shift_conv2d.bias = conv2d.bias
31 | model._modules[name] = shift_conv2d
32 | conversion_count += 1
33 |
34 | return model, conversion_count
35 |
36 | def get_shift_and_sign(x, rounding='deterministic'):
37 | sign = torch.sign(x)
38 |
39 | x_abs = torch.abs(x)
40 | shift = round(torch.log(x_abs) / np.log(2), rounding)
41 |
42 | return shift, sign
43 |
44 | def round(x, rounding='deterministic'):
45 | assert(rounding in ['deterministic', 'stochastic'])
46 | if rounding == 'stochastic':
47 | x_floor = x.floor()
48 | return x_floor + torch.bernoulli(x - x_floor)
49 | else:
50 | return x.round()
--------------------------------------------------------------------------------
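Note: convert_to_shift walks a trained float model and swaps each nn.Conv2d for a Conv2dShift whose shift/sign parameters are initialized from the conv weights. A minimal sketch on a toy model follows (same package-import caveat as above; the toy layers are illustrative).

import torch.nn as nn
from torchshiftadd.models.resnet20_shift import convert_to_shift

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False),
)
model, num_converted = convert_to_shift(model)
print(num_converted)                    # 2
print(type(model[0]).__name__)          # Conv2dShift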
/torchshiftadd/models/resnet20_shiftadd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchshiftadd.layers import adder, shift
5 |
6 | __all__ = ['resnet20_shiftadd']
7 |
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | shift_layer = shift.Conv2dShift(
11 | in_planes,
12 | out_planes,
13 | kernel_size=3,
14 | stride=stride,
15 | padding=1,
16 | bias=False
17 | )
18 | add_layer = adder.Adder2D(
19 | out_planes,
20 | out_planes,
21 | kernel_size=3,
22 | stride=1,
23 | padding=1,
24 | bias=False,
25 | )
26 | return nn.Sequential(shift_layer, add_layer)
27 |
28 |
29 | class BasicBlock(nn.Module):
30 |
31 | expansion=1
32 |
33 | def __init__(self, inplanes, planes, stride=1, downsample=None):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(inplanes, planes, stride=stride)
36 | self.bn1 = nn.BatchNorm2d(planes)
37 | self.relu = nn.ReLU(inplace=True)
38 | self.conv2 = conv3x3(planes, planes)
39 | self.bn2 = nn.BatchNorm2d(planes)
40 | self.downsample = downsample
41 | self.stride = stride
42 |
43 | def forward(self, x):
44 | residual = x
45 |
46 | out = self.conv1(x)
47 | out = self.bn1(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv2(out)
51 | out = self.bn2(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 |
56 | out += residual
57 | out = self.relu(out)
58 |
59 | return out
60 |
61 | class ResNet(nn.Module):
62 |
63 | def __init__(self, block, layers, num_classes=10):
64 | super(ResNet, self).__init__()
65 | self.inplanes = 16
66 | self.conv1 = shift.Conv2dShift(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
67 | self.bn1 = nn.BatchNorm2d(16)
68 | self.relu = nn.ReLU(inplace=True)
69 | self.layer1 = self._make_layer(block, 16, layers[0])
70 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
71 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
72 | self.avgpool = nn.AvgPool2d(8, stride=1)
73 | self.fc = shift.Conv2dShift(64 * block.expansion, num_classes, 1, bias=False)
74 | self.bn2 = nn.BatchNorm2d(num_classes)
75 |
76 |
77 | # init
78 | for m in self.modules():
79 | if isinstance(m, nn.BatchNorm2d):
80 | m.weight.data.fill_(1)
81 | m.bias.data.zero_()
82 |
83 | def _make_layer(self, block, planes, blocks, stride=1):
84 | downsample = None
85 | if stride != 1 or self.inplanes != planes * block.expansion:
86 | downsample = nn.Sequential(
87 | shift.Conv2dShift(
88 | self.inplanes,
89 | planes * block.expansion,
90 | kernel_size=1,
91 | stride=stride,
92 | bias=False,
93 | ),
94 | adder.Adder2D(
95 | planes * block.expansion,
96 | planes * block.expansion,
97 | kernel_size=1,
98 | stride=1,
99 | bias=False,
100 | ),
101 | nn.BatchNorm2d(planes * block.expansion)
102 | )
103 |
104 | layers = []
105 | layers.append(
106 | block(
107 | inplanes=self.inplanes,
108 | planes=planes,
109 | stride=stride,
110 | downsample=downsample,
111 | )
112 | )
113 | self.inplanes = planes * block.expansion
114 | for _ in range(1, blocks):
115 | layers.append(
116 | block(
117 | inplanes=self.inplanes,
118 | planes=planes,
119 | )
120 | )
121 |
122 | return nn.Sequential(*layers)
123 |
124 | def forward(self, x):
125 | x = self.conv1(x)
126 | x = self.bn1(x)
127 | x = self.relu(x)
128 |
129 | x = self.layer1(x)
130 | x = self.layer2(x)
131 | x = self.layer3(x)
132 |
133 | x = self.avgpool(x)
134 | x = self.fc(x)
135 | x = self.bn2(x)
136 |
137 | return x.view(x.size(0), -1)
138 |
139 | def resnet20_shiftadd(num_classes=10, **kwargs):
140 | return ResNet(
141 | BasicBlock,
142 | [3, 3, 3],
143 | num_classes=num_classes,
144 | )
--------------------------------------------------------------------------------
/torchshiftadd/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .torch import load_extension
2 | from . import comm
3 | from .ckpt_loading import load_add_state_dict, load_shiftadd_state_dict
4 | from .test_acc import test_acc
5 |
6 | __all__ = [
7 | "load_extension",
8 | "comm",
9 | "load_add_state_dict",
10 | "load_shiftadd_state_dict",
11 | "test_acc"
12 | ]
--------------------------------------------------------------------------------
/torchshiftadd/utils/ckpt_loading.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | def load_add_state_dict(state_dict):
4 | new_state_dict = OrderedDict()
5 | for k, v in state_dict.items():
6 | if 'weight' in k and not 'bn' in k and not 'fc' in k:
7 | if k == 'conv1.weight' or 'downsample.1' in k:
8 | new_state_dict[k] = v
9 | continue
10 | k = k[:-6] + 'adder'
11 | # print(k)
12 | new_state_dict[k] = v
13 | return new_state_dict
14 |
15 | def load_shiftadd_state_dict(state_dict):
16 | from collections import OrderedDict
17 | new_state_dict = OrderedDict()
18 | for k, v in state_dict.items():
19 | if 'weight' in k and not 'bn' in k and not 'fc' in k:
20 | if k == 'conv1.weight' or 'downsample.2' in k:
21 | new_state_dict[k] = v
22 | continue
23 | k = k[:-6] + 'adder'
24 | # print(k)
25 | new_state_dict[k] = v
26 | return new_state_dict
--------------------------------------------------------------------------------
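Note: a tiny illustration of the key remapping above, using dummy values (the keys below are illustrative). Vanilla '...convN.weight' entries are renamed to the 'adder' parameter expected by Adder2D layers, while the stem conv and downsample conv keys pass through unchanged.

from collections import OrderedDict
from torchshiftadd.utils import load_add_state_dict

ckpt = OrderedDict([
    ("conv1.weight", "stem conv, kept as-is"),
    ("layer1.0.conv1.weight", "renamed for Adder2D"),
    ("layer1.0.downsample.1.weight", "downsample conv, kept as-is"),
])
print(list(load_add_state_dict(ckpt).keys()))
# ['conv1.weight', 'layer1.0.conv1.adder', 'layer1.0.downsample.1.weight']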
/torchshiftadd/utils/comm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import multiprocessing
3 | from collections import defaultdict
4 |
5 | import torch
6 | from torch import distributed as dist
7 |
8 |
9 | cpu_group = None
10 | gpu_group = None
11 |
12 |
13 | def get_rank():
14 | """
15 | Get the rank of this process in distributed processes.
16 |
17 | Return 0 for single process case.
18 | """
19 | if dist.is_initialized():
20 | return dist.get_rank()
21 | if "RANK" in os.environ:
22 | return int(os.environ["RANK"])
23 | return 0
24 |
25 |
26 | def get_world_size():
27 | """
28 | Get the total number of distributed processes.
29 |
30 | Return 1 for single process case.
31 | """
32 | if dist.is_initialized():
33 | return dist.get_world_size()
34 | if "WORLD_SIZE" in os.environ:
35 | return int(os.environ["WORLD_SIZE"])
36 | return 1
37 |
38 |
39 | def get_group(device):
40 | """
41 | Get the process group corresponding to the given device.
42 |
43 | Parameters:
44 | device (torch.device): query device
45 | """
46 | group = cpu_group if device.type == "cpu" else gpu_group
47 | if group is None:
48 | raise ValueError("%s group is not initialized. Use comm.init_process_group() to initialize it"
49 | % device.type.upper())
50 | return group
51 |
52 |
53 | def init_process_group(backend, init_method=None, **kwargs):
54 | """
55 | Initialize CPU and/or GPU process groups.
56 |
57 | Parameters:
58 | backend (str): Communication backend. Use ``nccl`` for GPUs and ``gloo`` for CPUs.
59 | init_method (str, optional): URL specifying how to initialize the process group
60 | """
61 | global cpu_group
62 | global gpu_group
63 |
64 | dist.init_process_group(backend, init_method, **kwargs)
65 | gpu_group = dist.group.WORLD
66 | if backend == "nccl":
67 | cpu_group = dist.new_group(backend="gloo")
68 | else:
69 | cpu_group = gpu_group
70 |
71 |
72 | def get_cpu_count():
73 | """
74 | Get the number of CPUs on this node.
75 | """
76 | return multiprocessing.cpu_count()
77 |
78 |
79 | def synchronize():
80 | """
81 | Synchronize among all distributed processes.
82 | """
83 | if get_world_size() > 1:
84 | dist.barrier()
85 |
86 |
87 | def _recursive_read(obj):
88 | values = defaultdict(list)
89 | sizes = defaultdict(list)
90 | if isinstance(obj, torch.Tensor):
91 | values[obj.dtype] += [obj.flatten()]
92 | sizes[obj.dtype] += [torch.tensor([obj.numel()], device=obj.device)]
93 | elif isinstance(obj, dict):
94 | for v in obj.values():
95 | child_values, child_sizes = _recursive_read(v)
96 | for k, v in child_values.items():
97 | values[k] += v
98 | for k, v in child_sizes.items():
99 | sizes[k] += v
100 | elif isinstance(obj, list) or isinstance(obj, tuple):
101 | for v in obj:
102 | child_values, child_sizes = _recursive_read(v)
103 | for k, v in child_values.items():
104 | values[k] += v
105 | for k, v in child_sizes.items():
106 | sizes[k] += v
107 | else:
108 | raise ValueError("Unknown type `%s`" % type(obj))
109 | return values, sizes
110 |
111 |
112 | def _recursive_write(obj, values, sizes=None):
113 | if isinstance(obj, torch.Tensor):
114 | if sizes is None:
115 | size = torch.tensor([obj.numel()], device=obj.device)
116 | else:
117 | s = sizes[obj.dtype]
118 | size, s = s.split([1, len(s) - 1])
119 | sizes[obj.dtype] = s
120 | v = values[obj.dtype]
121 | new_obj, v = v.split([size, v.shape[-1] - size], dim=-1)
122 | # compatible with reduce / stack / cat
123 | new_obj = new_obj.view(new_obj.shape[:-1] + (-1,) + obj.shape[1:])
124 | values[obj.dtype] = v
125 | return new_obj, values
126 | elif isinstance(obj, dict):
127 | new_obj = {}
128 | for k, v in obj.items():
129 | new_obj[k], values = _recursive_write(v, values, sizes)
130 | elif isinstance(obj, list) or isinstance(obj, tuple):
131 | new_obj = []
132 | for v in obj:
133 | new_v, values = _recursive_write(v, values, sizes)
134 | new_obj.append(new_v)
135 | else:
136 | raise ValueError("Unknown type `%s`" % type(obj))
137 | return new_obj, values
138 |
139 |
140 | def reduce(obj, op="sum", dst=None):
141 | """
142 | Reduce any nested container of tensors.
143 |
144 | Parameters:
145 | obj (Object): any container object. Can be nested list, tuple or dict.
146 | op (str, optional): element-wise reduction operator.
147 | Available operators are ``sum``, ``mean``, ``min``, ``max``, ``product``.
148 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.
149 |
150 | Example::
151 |
152 | >>> # assume 4 workers
153 | >>> rank = comm.get_rank()
154 | >>> x = torch.rand(5)
155 | >>> obj = {"polynomial": x ** rank}
156 | >>> obj = comm.reduce(obj)
157 | >>> assert torch.allclose(obj["polynomial"], x ** 3 + x ** 2 + x + 1)
158 | """
159 | values = _recursive_read(obj)[0]
160 | values = {k: torch.cat(v) for k, v in values.items()}
161 |
162 | is_mean = op == "mean"
163 | if is_mean:
164 | op = "sum"
165 | op = getattr(dist.ReduceOp, op.upper())
166 |
167 | reduced = {}
168 | for k, v in values.items():
169 | dtype = v.dtype
170 | # NCCL can't solve bool. Cast them to byte
171 | if dtype == torch.bool:
172 | v = v.byte()
173 | group = get_group(v.device)
174 | if dst is None:
175 | dist.all_reduce(v, op=op, group=group)
176 | else:
177 | dist.reduce(v, op=op, dst=dst, group=group)
178 | if is_mean:
179 | v = v / get_world_size()
180 | reduced[k] = v.type(dtype)
181 |
182 | return _recursive_write(obj, reduced)[0]
183 |
184 |
185 | def stack(obj, dst=None):
186 | """
187 | Stack any nested container of tensors. The new dimension will be added at the 0-th axis.
188 |
189 | Parameters:
190 | obj (Object): any container object. Can be nested list, tuple or dict.
191 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.
192 |
193 | Example::
194 |
195 | >>> # assume 4 workers
196 | >>> rank = comm.get_rank()
197 | >>> x = torch.rand(5)
198 | >>> obj = {"exponent": x ** rank}
199 | >>> obj = comm.stack(obj)
200 | >>> truth = torch.stack([torch.ones_like(x), x, x ** 2, x ** 3])
201 | >>> assert torch.allclose(obj["exponent"], truth)
202 | """
203 | values = _recursive_read(obj)[0]
204 | values = {k: torch.cat(v) for k, v in values.items()}
205 |
206 | stacked = {}
207 | for k, v in values.items():
208 | dtype = v.dtype
209 | # NCCL can't solve bool. Cast them to byte
210 | if dtype == torch.bool:
211 | dtype = torch.uint8
212 | s = torch.zeros(get_world_size(), *v.shape, dtype=dtype, device=v.device)
213 | s[get_rank()] = v
214 | group = get_group(s.device)
215 | if dst is None:
216 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group)
217 | else:
218 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group)
219 | stacked[k] = s.type(v.dtype)
220 |
221 | return _recursive_write(obj, stacked)[0]
222 |
223 |
224 | def cat(obj, dst=None):
225 | """
226 | Concatenate any nested container of tensors along the 0-th axis.
227 |
228 | Parameters:
229 | obj (Object): any container object. Can be nested list, tuple or dict.
230 | dst (int, optional): rank of destination worker. If not specified, broadcast the result to all workers.
231 |
232 | Example::
233 |
234 | >>> # assume 4 workers
235 | >>> rank = comm.get_rank()
236 | >>> rng = torch.arange(10)
237 | >>> obj = {"range": rng[rank * (rank + 1) // 2: (rank + 1) * (rank + 2) // 2]}
238 | >>> obj = comm.cat(obj)
239 | >>> assert torch.allclose(obj["range"], rng)
240 | """
241 | values, sizes = _recursive_read(obj)
242 | sizes = {k: torch.cat(v) for k, v in sizes.items()}
243 |
244 | sizes = stack(sizes)
245 | cated = {}
246 | for k, value in values.items():
247 | size = sizes[k].t().flatten() # sizes[k]: (num_worker, num_obj)
248 | dtype = value[0].dtype
249 | # NCCL can't solve bool. Cast them to byte
250 | if dtype == torch.bool:
251 | dtype = torch.uint8
252 | s = torch.zeros(size.sum(), dtype=dtype, device=value[0].device)
253 | obj_id = get_rank()
254 | world_size = get_world_size()
255 | offset = size[:obj_id].sum()
256 | for v in value:
257 | assert offset + v.numel() <= len(s)
258 | s[offset: offset + v.numel()] = v
259 | offset += size[obj_id: obj_id + world_size].sum()
260 | obj_id += world_size
261 | group = get_group(s.device)
262 | if dst is None:
263 | dist.all_reduce(s, op=dist.ReduceOp.SUM, group=group)
264 | else:
265 | dist.reduce(s, op=dist.ReduceOp.SUM, dst=dst, group=group)
266 | cated[k] = s.type(value[0].dtype)
267 | sizes = {k: v.sum(dim=0) for k, v in sizes.items()}
268 |
269 | return _recursive_write(obj, cated, sizes)[0]
--------------------------------------------------------------------------------
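Note: the helpers above degrade gracefully outside torch.distributed, which makes them safe to call from single-process scripts; reduce/stack/cat still require an initialized process group (see init_process_group). A minimal single-process check, assuming the package imports cleanly:

from torchshiftadd.utils import comm

# Without an initialized process group (and no RANK/WORLD_SIZE env vars),
# these fall back to the single-process defaults.
print(comm.get_rank(), comm.get_world_size())   # 0 1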
/torchshiftadd/utils/decorator.py:
--------------------------------------------------------------------------------
1 | class cached_property(property):
2 | """
3 | Cache the property once computed.
4 | """
5 |
6 | def __init__(self, func):
7 | self.func = func
8 | self.__doc__ = func.__doc__
9 |
10 | def __get__(self, obj, cls):
11 | if obj is None:
12 | return self
13 | result = self.func(obj)
14 | obj.__dict__[self.func.__name__] = result
15 | return result
16 |
--------------------------------------------------------------------------------
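Note: a usage sketch for cached_property above; the class and attribute names below are illustrative.

from torchshiftadd.utils.decorator import cached_property

class Squares:
    def __init__(self, n):
        self.n = n

    @cached_property
    def values(self):
        # evaluated on attribute access; __get__ also stores the result
        # in the instance __dict__
        return [i * i for i in range(self.n)]

print(Squares(4).values)    # [0, 1, 4, 9]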
/torchshiftadd/utils/quantize.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd.function import InplaceFunction, Function
7 |
8 | QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits'])
9 |
10 | _DEFAULT_FLATTEN = (1, -1)
11 | _DEFAULT_FLATTEN_GRAD = (0, -1)
12 |
13 |
14 | def _deflatten_as(x, x_full):
15 | shape = list(x.shape) + [1] * (x_full.dim() - x.dim())
16 | return x.view(*shape)
17 |
18 |
19 | def calculate_qparams(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False, true_zero=False):
20 | with torch.no_grad():
21 | x_flat = x.flatten(*flatten_dims)
22 | if x_flat.dim() == 1:
23 | min_values = _deflatten_as(x_flat.min(), x)
24 | max_values = _deflatten_as(x_flat.max(), x)
25 | else:
26 | min_values = _deflatten_as(x_flat.min(-1)[0], x)
27 | max_values = _deflatten_as(x_flat.max(-1)[0], x)
28 |
29 | if reduce_dim is not None:
30 | if reduce_type == 'mean':
31 | min_values = min_values.mean(reduce_dim, keepdim=keepdim)
32 | max_values = max_values.mean(reduce_dim, keepdim=keepdim)
33 | else:
34 | min_values = min_values.min(reduce_dim, keepdim=keepdim)[0]
35 | max_values = max_values.max(reduce_dim, keepdim=keepdim)[0]
36 |
37 | range_values = max_values - min_values
38 | return QParams(range=range_values, zero_point=min_values,
39 | num_bits=num_bits)
40 |
41 |
42 | class UniformQuantize(InplaceFunction):
43 |
44 | @staticmethod
45 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN,
46 | reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False):
47 |
48 | ctx.inplace = inplace
49 |
50 | if ctx.inplace:
51 | ctx.mark_dirty(input)
52 | output = input
53 | else:
54 | output = input.clone()
55 |
56 | if qparams is None:
57 | assert num_bits is not None, "either provide qparams or num_bits to quantize"
58 | qparams = calculate_qparams(
59 | input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)
60 |
61 | zero_point = qparams.zero_point
62 | num_bits = qparams.num_bits
63 | qmin = -(2.**(num_bits - 1)) if signed else 0.
64 | qmax = qmin + 2.**num_bits - 1.
65 | scale = qparams.range / (qmax - qmin)
66 |
67 | min_scale = torch.tensor(1e-8).expand_as(scale).cuda()
68 | scale = torch.max(scale, min_scale)
69 |
70 | with torch.no_grad():
71 | output.add_(qmin * scale - zero_point).div_(scale)
72 | if stochastic:
73 | # print('use stochastic')
74 | noise = output.new(output.shape).uniform_(-0.5, 0.5)
75 | output.add_(noise)
76 | # quantize
77 | output.clamp_(qmin, qmax).round_()
78 |
79 | if dequantize:
80 | output.mul_(scale).add_(
81 | zero_point - qmin * scale) # dequantize
82 | return output
83 |
84 | @staticmethod
85 | def backward(ctx, grad_output):
86 | # straight-through estimator
87 | grad_input = grad_output
88 | return grad_input, None, None, None, None, None, None, None, None
89 |
90 |
91 | class UniformQuantizeGrad(InplaceFunction):
92 |
93 | @staticmethod
94 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD,
95 | reduce_dim=0, dequantize=True, signed=False, stochastic=True):
96 | ctx.num_bits = num_bits
97 | ctx.qparams = qparams
98 | ctx.flatten_dims = flatten_dims
99 | ctx.stochastic = stochastic
100 | ctx.signed = signed
101 | ctx.dequantize = dequantize
102 | ctx.reduce_dim = reduce_dim
103 | ctx.inplace = False
104 | return input
105 |
106 | @staticmethod
107 | def backward(ctx, grad_output):
108 | qparams = ctx.qparams
109 | with torch.no_grad():
110 | if qparams is None:
111 | assert ctx.num_bits is not None, "either provide qparams or num_bits to quantize"
112 | qparams = calculate_qparams(
113 | grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, reduce_type='extreme')
114 |
115 | grad_input = quantize(grad_output, num_bits=None,
116 | qparams=qparams, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim,
117 | dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False)
118 | return grad_input, None, None, None, None, None, None, None
119 |
120 |
121 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None):
122 | out1 = F.conv2d(input.detach(), weight, bias,
123 | stride, padding, dilation, groups)
124 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None,
125 | stride, padding, dilation, groups)
126 | out2 = quantize_grad(out2, num_bits=num_bits_grad, flatten_dims=(1, -1))
127 | return out1 + out2 - out1.detach()
128 |
129 |
130 | def linear_biprec(input, weight, bias=None, num_bits_grad=None):
131 | out1 = F.linear(input.detach(), weight, bias)
132 | out2 = F.linear(input, weight.detach(), bias.detach()
133 | if bias is not None else None)
134 | out2 = quantize_grad(out2, num_bits=num_bits_grad)
135 | return out1 + out2 - out1.detach()
136 |
137 |
138 | def quantize(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False):
139 | if qparams:
140 | if qparams.num_bits:
141 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace)
142 | elif num_bits:
143 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace)
144 |
145 | return x
146 |
147 |
148 | def quantize_grad(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True, signed=False, stochastic=True):
149 | if qparams:
150 | if qparams.num_bits:
151 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic)
152 | elif num_bits:
153 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic)
154 |
155 | return x
156 |
157 |
158 | class QuantMeasure(nn.Module):
159 |     """Tracks a running range/zero-point of its input and fake-quantizes it to num_bits."""
160 |
161 | def __init__(self, shape_measure=(1,), flatten_dims=_DEFAULT_FLATTEN,
162 | inplace=False, dequantize=True, stochastic=False, momentum=0.9, measure=False):
163 | super(QuantMeasure, self).__init__()
164 | self.register_buffer('running_zero_point', torch.zeros(*shape_measure))
165 | self.register_buffer('running_range', torch.zeros(*shape_measure))
166 | self.measure = measure
167 | if self.measure:
168 | self.register_buffer('num_measured', torch.zeros(1))
169 | self.flatten_dims = flatten_dims
170 | self.momentum = momentum
171 | self.dequantize = dequantize
172 | self.stochastic = stochastic
173 | self.inplace = inplace
174 |
175 | def forward(self, input, num_bits, qparams=None):
176 |
177 | if self.training or self.measure:
178 | if qparams is None:
179 | qparams = calculate_qparams(
180 | input, num_bits=num_bits, flatten_dims=self.flatten_dims, reduce_dim=0, reduce_type='extreme')
181 | with torch.no_grad():
182 | if self.measure:
183 | momentum = self.num_measured / (self.num_measured + 1)
184 | self.num_measured += 1
185 | else:
186 | momentum = self.momentum
187 | self.running_zero_point.mul_(momentum).add_(
188 | qparams.zero_point * (1 - momentum))
189 | self.running_range.mul_(momentum).add_(
190 | qparams.range * (1 - momentum))
191 | else:
192 | qparams = QParams(range=self.running_range,
193 | zero_point=self.running_zero_point, num_bits=num_bits)
194 | if self.measure:
195 | return input
196 | else:
197 | q_input = quantize(input, qparams=qparams, dequantize=self.dequantize,
198 | stochastic=self.stochastic, inplace=self.inplace)
199 | return q_input
200 |
201 |
202 | class QConv2d(nn.Conv2d):
203 |     """Conv2d with quantized inputs, weights, and (optionally) gradients."""
204 |
205 | def __init__(self, in_channels, out_channels, kernel_size,
206 | stride=1, padding=0, dilation=1, groups=1, bias=True, momentum=0.1, quant_act_forward=0, quant_act_backward=0, quant_grad_act_error=0, quant_grad_act_gc=0, weight_bits=0, fix_prec=False):
207 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
208 | stride, padding, dilation, groups, bias)
209 |
210 | self.quantize_input_fw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum)
211 | self.quantize_input_bw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum)
212 | self.quant_act_forward = quant_act_forward
213 | self.quant_act_backward = quant_act_backward
214 | self.quant_grad_act_error = quant_grad_act_error
215 | self.quant_grad_act_gc = quant_grad_act_gc
216 | self.weight_bits = weight_bits
217 | self.fix_prec = fix_prec
218 | self.stride = stride
219 |
220 |
221 | def forward(self, input, num_bits, num_grad_bits):
222 | if num_bits == 0:
223 | output = F.conv2d(input, self.weight, self.bias, self.stride,self.padding, self.dilation, self.groups)
224 | return output
225 |
226 | if self.bias is not None:
227 |             # QConv2d defines no num_bits_weight/num_bits attributes; approximate the intended
228 |             # bias precision with weight_bits plus the per-call activation bit-width.
229 |             qbias = quantize(self.bias, num_bits=self.weight_bits + num_bits, flatten_dims=(0, -1))
230 | else:
231 | qbias = None
232 |
233 | if self.fix_prec:
234 | if self.quant_act_forward or self.quant_act_backward or self.quant_grad_act_error or self.quant_grad_act_gc or self.weight_bits:
235 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None)
236 | qweight = quantize(self.weight, qparams=weight_qparams)
237 |
238 | qinput_fw = self.quantize_input_fw(input, self.quant_act_forward)
239 | qinput_bw = self.quantize_input_bw(input, self.quant_act_backward)
240 |
241 | error_bits = self.quant_grad_act_error
242 | gc_bits = self.quant_grad_act_gc
243 | output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits)
244 |
245 | else:
246 | qinput = self.quantize_input_fw(input, num_bits)
247 | weight_qparams = calculate_qparams(self.weight, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None)
248 | qweight = quantize(self.weight, qparams=weight_qparams)
249 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups)
250 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1))
251 |
252 | return output
253 |
254 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None)
255 | qweight = quantize(self.weight, qparams=weight_qparams)
256 |
257 | qinput = self.quantize_input_fw(input, num_bits)
258 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups)
259 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1))
260 |
261 | # if self.quant_act_forward == -1:
262 | # qinput_fw = self.quantize_input_fw(input, num_bits)
263 | # else:
264 | # qinput_fw = self.quantize_input_fw(input, self.quant_act_forward)
265 |
266 | # if self.quant_act_backward == -1:
267 | # qinput_bw = self.quantize_input_bw(input, num_bits)
268 | # else:
269 | # qinput_bw = self.quantize_input_bw(input, self.quant_act_backward)
270 |
271 | # if self.quant_grad_act_error == -1:
272 | # error_bits = num_grad_bits
273 | # else:
274 | # error_bits = self.quant_grad_act_error
275 |
276 | # if self.quant_grad_act_gc == -1:
277 | # gc_bits = num_grad_bits
278 | # else:
279 | # gc_bits = self.quant_grad_act_gc
280 |
281 | # output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits)
282 |
283 | return output
284 |
285 |
286 | def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, error_bits=0, gc_bits=0):
287 | out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None,
288 | stride, padding, dilation, groups)
289 | out2 = F.conv2d(input_bw.detach(), weight, bias,
290 | stride, padding, dilation, groups)
291 | out1 = quantize_grad(out1, num_bits=error_bits)
292 | out2 = quantize_grad(out2, num_bits=gc_bits)
293 | return out1 + out2 - out2.detach()
294 |
295 |
296 | class QLinear(nn.Linear):
297 |     """Linear layer with quantized inputs, weights, and gradients."""
298 |
299 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=8, num_bits_grad=8, biprecision=False):
300 | super(QLinear, self).__init__(in_features, out_features, bias)
301 | self.num_bits = num_bits
302 | self.num_bits_weight = num_bits_weight or num_bits
303 | self.num_bits_grad = num_bits_grad
304 | self.biprecision = biprecision
305 | self.quantize_input = QuantMeasure(shape_measure=(1, 1), flatten_dims=(1, -1), momentum=0.1)
306 |
307 | def forward(self, input, num_bits, num_bits_grad):
308 | # self.quantize_input = QuantMeasure(num_bits)
309 |
310 | qinput = self.quantize_input(input, num_bits)
311 | weight_qparams = calculate_qparams(
312 | self.weight, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None)
313 | qweight = quantize(self.weight, qparams=weight_qparams)
314 | if self.bias is not None:
315 | qbias = quantize(
316 | self.bias, num_bits=num_bits,
317 | flatten_dims=(0, -1))
318 | else:
319 | qbias = None
320 |
321 | # if not self.biprecision or self.num_bits_grad is None:
322 | output = F.linear(qinput, qweight, qbias)
323 | # if self.num_bits_grad is not None:
324 | output = quantize_grad(
325 | output, num_bits=num_bits_grad)
326 | # else:
327 | # output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad)
328 |
329 |
330 | return output
331 |
332 |
333 | class RangeBN(nn.Module):
334 | # this is normalized RangeBN
335 |
336 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8):
337 | super(RangeBN, self).__init__()
338 | self.register_buffer('running_mean', torch.zeros(num_features))
339 | self.register_buffer('running_var', torch.zeros(num_features))
340 |
341 | self.momentum = momentum
342 | self.dim = dim
343 | if affine:
344 | self.bias = nn.Parameter(torch.Tensor(num_features))
345 | self.weight = nn.Parameter(torch.Tensor(num_features))
346 | self.num_bits = num_bits
347 | self.num_bits_grad = num_bits_grad
348 | self.quantize_input = QuantMeasure(inplace=True, shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1))
349 | self.eps = eps
350 | self.num_chunks = num_chunks
351 | self.reset_params()
352 |
353 | def reset_params(self):
354 | if self.weight is not None:
355 | self.weight.data.uniform_()
356 | if self.bias is not None:
357 | self.bias.data.zero_()
358 |
359 | def forward(self, x, num_bits, num_grad_bits):
360 | x = self.quantize_input(x, num_bits)
361 | if x.dim() == 2: # 1d
362 |             x = x.unsqueeze(-1).unsqueeze(-1)
363 |
364 | if self.training:
365 | B, C, H, W = x.shape
366 | y = x.transpose(0, 1).contiguous() # C x B x H x W
367 | y = y.view(C, self.num_chunks, (B * H * W) // self.num_chunks)
368 | mean_max = y.max(-1)[0].mean(-1) # C
369 | mean_min = y.min(-1)[0].mean(-1) # C
370 | mean = y.view(C, -1).mean(-1) # C
371 | scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) **
372 | 0.5) / ((2 * math.log(y.size(-1))) ** 0.5)
373 |
374 | scale = (mean_max - mean_min) * scale_fix
375 | with torch.no_grad():
376 | self.running_mean.mul_(self.momentum).add_(
377 | mean * (1 - self.momentum))
378 |
379 | self.running_var.mul_(self.momentum).add_(
380 | scale * (1 - self.momentum))
381 | else:
382 | mean = self.running_mean
383 | scale = self.running_var
384 | # scale = quantize(scale, num_bits=self.num_bits, min_value=float(
385 | # scale.min()), max_value=float(scale.max()))
386 | out = (x - mean.view(1, -1, 1, 1)) / \
387 | (scale.view(1, -1, 1, 1) + self.eps)
388 |
389 | if self.weight is not None:
390 | qweight = self.weight
391 | # qweight = quantize(self.weight, num_bits=self.num_bits,
392 | # min_value=float(self.weight.min()),
393 | # max_value=float(self.weight.max()))
394 | out = out * qweight.view(1, -1, 1, 1)
395 |
396 | if self.bias is not None:
397 | qbias = self.bias
398 | # qbias = quantize(self.bias, num_bits=self.num_bits)
399 | out = out + qbias.view(1, -1, 1, 1)
400 | if num_grad_bits:
401 | out = quantize_grad(
402 | out, num_bits=num_grad_bits, flatten_dims=(1, -1))
403 |
404 | if out.size(3) == 1 and out.size(2) == 1:
405 | out = out.squeeze(-1).squeeze(-1)
406 | return out
407 |
408 |
409 | class RangeBN1d(RangeBN):
410 | # this is normalized RangeBN
411 |
412 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8):
413 | super(RangeBN1d, self).__init__(num_features, dim, momentum,
414 | affine, num_chunks, eps, num_bits, num_bits_grad)
415 | self.quantize_input = QuantMeasure(
416 |             inplace=True, shape_measure=(1, 1), flatten_dims=(1, -1))  # QuantMeasure does not take a num_bits argument
417 |
418 | if __name__ == '__main__':
419 | x = torch.rand(2, 3)
420 |     x_q = quantize(x, flatten_dims=(0, -1), num_bits=8, dequantize=True)  # flatten_dims must be a tuple
421 | print(x)
422 | print(x_q)
--------------------------------------------------------------------------------
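For orientation, a minimal usage sketch of the fake-quantization helpers defined above (the tensor shapes and the import path torchshiftadd.utils.quantize are assumptions for illustration, not part of the file):

    import torch
    from torchshiftadd.utils.quantize import QuantMeasure, calculate_qparams, quantize

    x = torch.randn(4, 16)

    # Per-tensor 8-bit fake quantization: compute range/zero-point, then quantize-dequantize.
    qparams = calculate_qparams(x, num_bits=8, flatten_dims=(0, -1), reduce_dim=0)
    x_q = quantize(x, qparams=qparams, dequantize=True)

    # QuantMeasure keeps running range statistics during training and reuses them at eval time.
    qm = QuantMeasure(shape_measure=(1, 1), flatten_dims=(1, -1))
    qm.train()
    y = qm(x, num_bits=8)
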
/torchshiftadd/utils/ste.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | import numpy as np
4 | import math
5 |
6 | def dynamic_range_for_sign(sign, threshold):
7 | # print(sign, threshold)
8 | with torch.no_grad():
9 | sign.data[sign.data < -threshold] = -1
10 | sign.data[sign.data > threshold] = 1
11 | sign.data[(-threshold <= sign.data) & (sign.data <= threshold)] = 0
12 | return sign
13 |
14 | class myRoundFunction(Function):
15 | @staticmethod
16 | def forward(ctx, input, threshold):
17 | return dynamic_range_for_sign(input, threshold)
18 |
19 | @staticmethod
20 | def backward(ctx, grad_output):
21 | return grad_output, None
22 |
23 | def myround(input, threshold):
24 | return myRoundFunction.apply(input, threshold)
25 |
26 |
27 | ##############
28 |
29 | def get_shift_and_sign(x, rounding='deterministic'):
30 | sign = torch.sign(x)
31 |
32 | x_abs = torch.abs(x)
33 | shift = round(torch.log(x_abs) / np.log(2), rounding)
34 |
35 | return shift, sign
36 |
37 | def _round_power_of_2(x, rounding='deterministic'):  # renamed: the public STE wrapper below would otherwise shadow it and recurse
38 | shift, sign = get_shift_and_sign(x, rounding)
39 | # print(shift)
40 | x_rounded = (2.0 ** shift) * sign
41 | return x_rounded
42 |
43 | class RoundPowerOf2(Function):
44 | @staticmethod
45 | def forward(ctx, input, stochastic=False):
46 |         return _round_power_of_2(input, 'stochastic' if stochastic else 'deterministic')
47 |
48 | @staticmethod
49 | def backward(ctx, grad_output):
50 | return grad_output, None
51 |
52 | def round_power_of_2(input, stochastic=False):
53 | return RoundPowerOf2.apply(input, stochastic)
54 |
55 | def round_to_fixed(input, fraction=16, integer=16):
56 | assert integer >= 1, integer
57 | if integer == 1:
58 | return torch.sign(input) - 1
59 | delta = math.pow(2.0, -(fraction))
60 | bound = math.pow(2.0, integer-1)
61 | min_val = - bound
62 | max_val = bound - 1
63 | rounded = torch.floor(input / delta) * delta
64 |
65 | clipped_value = torch.clamp(rounded, min_val, max_val)
66 | return clipped_value
67 |
68 | class RoundFixedPoint(Function):
69 | @staticmethod
70 | def forward(ctx, input, quant_bits):
71 | return round_to_fixed(input, fraction=quant_bits)
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | return grad_output, None
76 |
77 | def round_fixed_point(input, quant_bits):
78 | return RoundFixedPoint.apply(input, quant_bits)
79 |
80 | def new_round(x, rounding='deterministic'):
81 | assert(rounding in ['deterministic', 'stochastic'])
82 | if rounding == 'stochastic':
83 | x_floor = x.floor()
84 | return x_floor + torch.bernoulli(x - x_floor)
85 | else:
86 | return x.round()
87 |
88 | class RoundFunction(Function):
89 | @staticmethod
90 | def forward(ctx, input, rounding='deterministic'):
91 | return new_round(input, rounding)
92 |
93 | @staticmethod
94 | def backward(ctx, grad_output):
95 | return grad_output, None
96 |
97 | def round(input, rounding='deterministic'):
98 | return RoundFunction.apply(input, rounding)
99 |
100 | class SignFunction(Function):
101 | @staticmethod
102 | def forward(ctx, input):
103 | return torch.sign(input)
104 |
105 | @staticmethod
106 | def backward(ctx, grad_output):
107 | return grad_output
108 |
109 | def sign(input):
110 | return SignFunction.apply(input)
111 |
112 | class ClampFunction(Function):
113 | @staticmethod
114 | def forward(ctx, input, min, max):
115 | return torch.clamp(input, min, max)
116 |
117 | @staticmethod
118 | def backward(ctx, grad_output):
119 | return grad_output, None, None
120 |
121 | def clamp(input, min, max):
122 | return ClampFunction.apply(input, min, max)
123 |
124 | class ClampAbsFunction(Function):
125 | @staticmethod
126 | def forward(ctx, input, min, max):
127 | assert(min >= 0 and max >=0)
128 |
129 | input[input > max] = max
130 | input[input < -max] = -max
131 |
132 | input[(input > torch.zeros_like(input)) & (input < min)] = min
133 | input[(input < torch.zeros_like(input)) & (input > -min)] = -min
134 | return input
135 |
136 | @staticmethod
137 | def backward(ctx, grad_output):
138 | return grad_output, None, None
139 |
140 | def clampabs(input, min, max):
141 | return ClampAbsFunction.apply(input, min, max)
142 |
143 | class LogFunction(Function):
144 | @staticmethod
145 | def forward(ctx, input):
146 | return torch.log(input)
147 |
148 | @staticmethod
149 | def backward(ctx, grad_output):
150 | return grad_output
151 |
152 | def log(input):
153 | return LogFunction.apply(input)
154 |
155 | class UnsymmetricGradMulFunction(Function):
156 | @staticmethod
157 | def forward(ctx, input1, input2):
158 | ctx.save_for_backward(input1, input2)
159 | return torch.mul(input1, input2)
160 |
161 | @staticmethod
162 | def backward(ctx, grad_output):
163 | input1, input2 = ctx.saved_tensors
164 | return grad_output*input2, grad_output
165 |
166 | def unsym_grad_mul(input1, input2):
167 | return UnsymmetricGradMulFunction.apply(input1, input2)
168 |
169 |
170 | class AbsFunction(Function):
171 | @staticmethod
172 | def forward(ctx, input):
173 | return torch.abs(input)
174 |
175 | @staticmethod
176 | def backward(ctx, grad_output):
177 | return grad_output
178 |
179 | def abs(input):
180 | return AbsFunction.apply(input)
--------------------------------------------------------------------------------
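As a quick illustration of the straight-through estimators above, the sketch below (import path assumed) rounds a tensor to signed powers of two while leaving its gradient untouched:

    import torch
    from torchshiftadd.utils import ste

    w = torch.tensor([0.3, -1.7, 2.2], requires_grad=True)
    w_pow2 = ste.round_power_of_2(w)   # tensor([ 0.2500, -2.0000,  2.0000]): values snap to signed powers of two
    w_pow2.sum().backward()
    print(w.grad)                      # tensor([1., 1., 1.]): the rounding is transparent to gradients
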
/torchshiftadd/utils/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 |
6 | def accuracy(output, target, topk=(1,)):
7 | """Computes the precision@k for the specified values of k"""
8 | maxk = max(topk)
9 | batch_size = target.size(0)
10 |
11 | _, pred = output.topk(maxk, 1, True, True)
12 | pred = pred.t()
13 | correct = pred.eq(target.view(1, -1).expand_as(pred))
14 |
15 | res = []
16 | for k in topk:
17 | correct_k = correct[:k].reshape(-1).float().sum(0)
18 | res.append(correct_k.mul_(100.0 / batch_size))
19 | return res
20 |
21 | def test_acc(model, test_loader):
22 | model.eval()
23 | test_loss = 0
24 | test_acc = 0
25 | test_acc_5 = 0
26 | for data, target in test_loader:
27 | data, target = data.cuda(), target.cuda()
28 |         with torch.no_grad():  # Variable/volatile are deprecated; disable autograd for evaluation
29 |             output = model(data)
30 |             test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
31 | prec1, prec5 = accuracy(output.data, target.data, topk=(1, 5))
32 | test_acc += prec1.item()
33 | test_acc_5 += prec5.item()
34 |
35 | test_loss /= len(test_loader.dataset)
36 | print('\nTest set: Average loss: {:.4f}, Prec1: {}/{} ({:.2f}%), Prec5: ({:.2f}%)\n'.format(
37 | test_loss, test_acc, len(test_loader), test_acc / len(test_loader), test_acc_5 / len(test_loader)))
38 | return np.round(test_acc / len(test_loader), 2), np.round(test_acc_5 / len(test_loader), 2)
--------------------------------------------------------------------------------
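A small sanity check for the accuracy helper above (the import path and the random logits/labels are purely illustrative):

    import torch
    from torchshiftadd.utils.test_acc import accuracy

    logits = torch.randn(8, 10)               # batch of 8 samples, 10 classes
    labels = torch.randint(0, 10, (8,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    print(top1.item(), top5.item())           # percentages in [0, 100]
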
/torchshiftadd/utils/torch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.utils import cpp_extension
5 | from . import decorator, comm
6 |
7 | class LazyExtensionLoader(object):
8 |
9 | def __init__(self, name, sources, extra_cflags=None, extra_cuda_cflags=None, extra_ldflags=None,
10 | extra_include_paths=None, build_directory=None, verbose=False, **kwargs):
11 | self.name = name
12 | self.sources = sources
13 | self.extra_cflags = extra_cflags
14 | self.extra_cuda_cflags = extra_cuda_cflags
15 | self.extra_ldflags = extra_ldflags
16 | self.extra_include_paths = extra_include_paths
17 | worker_name = "%s_%d" % (name, comm.get_rank())
18 | self.build_directory = build_directory or cpp_extension._get_build_directory(worker_name, verbose)
19 | self.verbose = verbose
20 | self.kwargs = kwargs
21 |
22 | def __getattr__(self, key):
23 | return getattr(self.module, key)
24 |
25 | @decorator.cached_property
26 | def module(self):
27 | return cpp_extension.load(self.name, self.sources, self.extra_cflags, self.extra_cuda_cflags,
28 | self.extra_ldflags, self.extra_include_paths, self.build_directory,
29 | self.verbose, **self.kwargs)
30 |
31 |
32 | def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
33 | """
34 | Load a PyTorch C++ extension just-in-time (JIT).
35 | Automatically decide the compilation flags if not specified.
36 |
37 | This function performs lazy evaluation and is multi-process-safe.
38 |
39 | See `torch.utils.cpp_extension.load`_ for more details.
40 |
41 | .. _torch.utils.cpp_extension.load:
42 | https://pytorch.org/docs/stable/cpp_extension.html#torch.utils.cpp_extension.load
43 | """
44 | if extra_cflags is None:
45 | extra_cflags = ["-Ofast"]
46 | if torch.backends.openmp.is_available():
47 | extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
48 | else:
49 | extra_cflags.append("-DAT_PARALLEL_NATIVE")
50 | if extra_cuda_cflags is None:
51 | if torch.cuda.is_available():
52 | extra_cuda_cflags = ["-O3"]
53 | extra_cflags.append("-DCUDA_OP")
54 | else:
55 | new_sources = []
56 | for source in sources:
57 | if not cpp_extension._is_cuda_file(source):
58 | new_sources.append(source)
59 | sources = new_sources
60 |
61 | return LazyExtensionLoader(name, sources, extra_cflags, extra_cuda_cflags, **kwargs)
62 |
63 |
--------------------------------------------------------------------------------
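Finally, a hedged example of the lazy JIT loader above, pointing it at the bundled AdderNet CUDA sources from this repository; the adder_cuda.forward call at the end is only an assumed entry point of the compiled module:

    from torchshiftadd.utils.torch import load_extension

    adder_cuda = load_extension(
        "adder_cuda",
        sources=[
            "torchshiftadd/layers/extension/adder_cuda.cpp",
            "torchshiftadd/layers/extension/adder_cuda_kernel.cu",
        ],
    )
    # Nothing is compiled yet; the build runs on first attribute access,
    # e.g. adder_cuda.forward(...) (hypothetical symbol), and the build
    # directory is keyed by the process rank via comm.get_rank().
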