├── LICENSE ├── MANIFEST.in ├── README.md ├── examples ├── compare_conv2d.py ├── compare_linear.py ├── correctness_conv2d.py ├── correctness_linear.py ├── finetune_resnet18_imagenette.py ├── models │ └── resnet18_ac_dc_500_epochs_sp=0.95_uniform.pt ├── notebook.ipynb └── utils.py ├── setup.py └── sparseprop ├── __init__.py ├── backend.cpp ├── lib ├── sparse_conv2d.cpp ├── sparse_conv2d_over_on.cpp ├── sparse_linear.cpp └── utils.cpp ├── modules ├── __init__.py ├── conv2d.py ├── functions.py ├── linear.py └── utils.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include sparseprop/lib/*.cpp -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *SparseProp* 2 | 3 | Official implementation of the paper *"SparseProp: Efficient Sparse Backpropagation for Faster Training of Neural Networks"*. 
4 | 5 | [Link to the paper](https://arxiv.org/abs/2302.04852) 6 | 7 | This library provides fast PyTorch modules that exploit the sparse backpropagation algorithms described in the paper. 8 | 9 | ## Installation 10 | 1. Make sure you have PyTorch installed (refer to the [PyTorch website](https://pytorch.org)). A CPU version will suffice for our purposes. 11 | 2. Install *SparseProp*: 12 | ``` 13 | pip install sparseprop 14 | ``` 15 | 16 | ## Usage 17 | 18 | __Check out our [tutorial notebook](https://github.com/IST-DASLab/sparseprop/blob/main/examples/notebook.ipynb) for a simple, step-by-step guide on how to use *SparseProp*.__ 19 | 20 | #### Sparsifying a single layer 21 | If you have a sparse *Linear* module called `linear`, you can easily convert it to a *SparseLinear* module using the `from_dense` method. 22 | ``` 23 | from sparseprop.modules import SparseLinear 24 | 25 | sparse_linear = SparseLinear.from_dense(linear) 26 | ``` 27 | 28 | This will automatically store the parameters of the `linear` module in a sparse format and benefit from *SparseProp*'s efficient backend. You can treat `sparse_linear` as a normal PyTorch module, e.g., you can simply call `output = sparse_linear(input)`. 29 | 30 | A similar interface exists for a sparse *Conv2d* module (called `conv`): 31 | ``` 32 | from sparseprop.modules import SparseConv2d 33 | 34 | sparse_conv = SparseConv2d.from_dense(conv, vectorizing_over_on=False) 35 | ``` 36 | 37 | The only difference with the *Linear* case is that there is an additional boolean argument `vectorizing_over_on`. As described in the paper, we have two implementations for the convolution case, one performing the vectorization over the batch size `B`, and the other over the output width `ON`. Using this argument you can specify which one of the two implementations to use. A quick rule of thumb is that if the input width and height are small (e.g., less than 32) then `vectorizing_over_on=False` is faster. 38 | 39 | Alternatively, the `sparsify_conv2d_auto` method can automatically determine the correct value of `vectorizing_over_on`. 40 | 41 | ``` 42 | from sparseprop.modules import sparsify_conv2d_auto 43 | 44 | sparse_conv = sparsify_conv2d_auto(conv, input_shape, verbose=True) 45 | ``` 46 | 47 | Notice that you will need to feed the `input_shape` to this method, which should look something like (`batch_size`, `input_channels`, `input_height`, `input_width`). This method will create two sparse modules, one with `vectorizing_over_on=False` and the other one with `vectorizing_over_on=True`, run a randomly generated batch through both, and return the faster module based on forward+backward time. 48 | 49 | #### Sparsifying the whole network 50 | As explained in the paper, we replace each *Linear* or *Conv2d* layer in a network with a sparse one if the following conditions are met: 51 | 1. It is at least 80% sparse. 52 | 2. The sparse module is faster than the original dense one (in terms of forward+backward time). 53 | 54 | This behavior is implemented in the `swap_modules_with_sparse` method in `sparseprop.utils`. For example, if you have a sparse (global or uniform) `model`: 55 | 56 | ``` 57 | from sparseprop.utils import swap_modules_with_sparse 58 | 59 | sparse_model = swap_modules_with_sparse(model, input_shape, verbose=True) 60 | ``` 61 | 62 | Notice that you need to provide the `input_shape` to this method, which is easily accessible through your *DataLoader*, as sketched below.
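For instance, here is a minimal sketch, assuming `train_loader` is your training *DataLoader* and `model` is your sparse network (both names are placeholders for your own objects):

```
from sparseprop.utils import swap_modules_with_sparse

# read the input shape off one batch, e.g. (batch_size, input_channels, input_height, input_width)
input_shape = next(iter(train_loader))[0].shape

# replace eligible Linear/Conv2d layers with their sparse counterparts
sparse_model = swap_modules_with_sparse(model, input_shape, verbose=True)
```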
The `swap_modules_with_sparse` method will iterate through the network's layers and replace them with their sparse counterparts if the above two conditions are met. 63 | 64 | ## Examples 65 | In the `examples` folder, you can find several Python scripts that will help you get started with *SparseProp*. To get consistent timings, we refer you to [this article](https://easyperf.net/blog/2019/08/02/Perf-measurement-environment-on-Linux). You can use your favorite command-line tool if you want to limit the number of CPU cores on which the code executes, e.g., `taskset` or `numactl`. Refer to the "Set cpu affinity" section in the same article. 66 | 67 | #### Correctness check 68 | The files `correctness_linear.py` and `correctness_conv2d.py` will compare the output of the *SparseLinear* and *SparseConv2d* modules with PyTorch's *Linear* and *Conv2d*, respectively. You can tweak the parameters in the scripts to check the correctness in different cases. 69 | 70 | #### Layer-wise performance comparison 71 | The files `compare_linear.py` and `compare_conv2d.py` will compare the running time of the *SparseLinear* and *SparseConv2d* modules with PyTorch's *Linear* and *Conv2d*, respectively. You will find the results in the `plots` directory. Again, feel free to tweak the parameters in the scripts to compare the runtime in different cases. 72 | 73 | #### Sparse fine-tuning of ResNet18 on imagenette 74 | 75 | For this example to work, you will need to have the [*sparseml*](https://github.com/neuralmagic/sparseml) library installed, as we use it to conveniently load the imagenette dataset (`pip install sparseml`). 76 | 77 | The file `finetune_resnet18_imagenette.py` fine-tunes a pretrained sparse ResNet18 model on the imagenette dataset, keeping the sparsity masks fixed. In the `examples/models/` folder, we have also included a 95% uniformly pruned ResNet18 checkpoint trained on imagenet (using the [AC/DC](https://arxiv.org/abs/2106.12379) method). You can use the following command to run this script on 4 CPU cores. 78 | 79 | ``` 80 | taskset -c 0,1,2,3 nice -n 5 python finetune_resnet18_imagenette.py --checkpoint-path=models/resnet18_ac_dc_500_epochs_sp\=0.95_uniform.pt --output-dir=results/resnet18_ac_dc_500_epochs_sp\=0.95_uniform/ 81 | ``` 82 | 83 | Notice that "`0,1,2,3`" are the core numbers, so simply adjust them if your machine has fewer than 4 cores. Also, "`nice -n 5`" runs the process with a niceness of 5, i.e., a slightly lower scheduling priority. 84 | 85 | The most important arguments of this script are: 86 | - `--checkpoint-path`: Path to the pretrained checkpoint. 87 | - `--output-dir`: Path to a directory where you wish to write the results. 88 | - `--run-dense`: You can use this argument to run this script without *SparseProp*. 89 | 90 | For the complete list of arguments, refer to [here](https://github.com/IST-DASLab/sparseprop/blob/96a8f545461847effe863e4471d1cd80b33fc0a2/examples/finetune_resnet18_imagenette_95_uniform.py#L16). 91 | 92 | In addition to the loss and accuracy metrics, this script also reports the time spent in each part of the process. The timings include: 93 | 94 | - `avg_end_to_end_forward`: the average time spent in the forward pass, i.e., the `model(inputs)` line. 95 | - `avg_end_to_end_backward`: the average time spent in the backward pass, i.e., the `loss.backward()` line. 96 | - `avg_end_to_end_minibatch`: the average time spent processing a minibatch. This includes forward pass, backward pass, loss calculation, optimization step, etc.
Note that loading the data into memory is not included. 97 | - `avg_module_forward_sum`: the average time spent in the forward function of the modules *torch.nn.Linear*, *torch.nn.Conv2d*, *SparseLinear*, and *SparseConv2d*. 98 | - `avg_module_backward_sum`: the average time spent in the backward function of the modules *torch.nn.Linear*, *torch.nn.Conv2d*, *SparseLinear*, and *SparseConv2d*. 99 | 100 | ## Todo 101 | 1. Include outputs of the example scripts in the README. 102 | 2. Prepare an example script for training a sparse model from scratch using gradual magnitude pruning. This will most likely be integrated into [this](https://github.com/IST-DASLab/ACDC) repository. 103 | -------------------------------------------------------------------------------- /examples/compare_conv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from time import time 4 | from tqdm import tqdm 5 | from sparseprop.modules import attach_identity_and_time, SparseConv2d 6 | import math 7 | import os, sys 8 | 9 | def DenseConv(W, padding, stride): 10 | OC, IC, K, _ = W.shape 11 | conv = torch.nn.Conv2d(IC, OC, K, stride=stride, padding=padding) 12 | with torch.no_grad(): 13 | conv.weight.mul_(0.) 14 | conv.weight.add_(W) 15 | return conv 16 | 17 | def SparseConv2dOverON(W, padding, stride): 18 | return SparseConv2d(W, padding=padding, stride=stride, vectorizing_over_on=True) 19 | 20 | if __name__ == '__main__': 21 | B = 64 # batch size 22 | IC = 128 # input channels 23 | OC = 128 # output channels 24 | M = 32 # input height 25 | N = 32 # input width 26 | K = 3 # kernel size 27 | stride = 1 # stride 28 | padding = 0 # padding 29 | sparsities = [.8, .9, .95, .98, .99] 30 | reps = 3 31 | 32 | torch.manual_seed(10) 33 | tag = sys.argv[1] if len(sys.argv) > 1 else None 34 | 35 | OM = math.ceil((M + 2 * padding - K + 1) / stride) 36 | ON = math.ceil((N + 2 * padding - K + 1) / stride) 37 | 38 | module_fns = [DenseConv, SparseConv2d, SparseConv2dOverON] 39 | module_names = [m.__name__ for m in module_fns] 40 | 41 | forward_times = {m: [] for m in module_names} 42 | backward_times = {m: [] for m in module_names} 43 | 44 | for sparsity in sparsities: 45 | sp_forward_times = {m: [] for m in module_names} 46 | sp_backward_times = {m: [] for m in module_names} 47 | for _ in tqdm(range(reps)): 48 | W = torch.randn(OC, IC, K, K) 49 | mask = torch.rand_like(W) > sparsity 50 | W *= mask 51 | 52 | Y_orig = torch.randn(B, OC, OM, ON) 53 | 54 | X_orig = torch.randn(B, IC, M, N) 55 | X_orig.requires_grad_() 56 | X_orig.retain_grad() 57 | 58 | for module_name, module_fn in zip(module_names, module_fns): 59 | 60 | module = module_fn(W, padding=padding, stride=stride) 61 | X = X_orig.clone() 62 | Y = Y_orig.clone() 63 | 64 | bt, ft = attach_identity_and_time(module, X, Y, time_forward=True, time_backward=True) 65 | sp_forward_times[module_name].append(ft) 66 | sp_backward_times[module_name].append(bt) 67 | 68 | for mn in module_names: 69 | forward_times[mn].append(sum(sp_forward_times[mn]) / reps) 70 | backward_times[mn].append(sum(sp_backward_times[mn]) / reps) 71 | 72 | title = f'B{B}-IC{IC}-OC{OC}-M{M}-K{K}-S{stride}-P{padding}' 73 | if tag is not None: 74 | title += '-' + tag 75 | os.makedirs(f'plots/conv2d/{title}', exist_ok=False) 76 | 77 | for mn in module_names: 78 | plt.plot(sparsities, forward_times[mn], '-o', label=mn) 79 | plt.grid() 80 | plt.xlabel('sparsity') 81 | plt.ylabel('time') 82 | plt.ylim(bottom=0) 83 | 
plt.title(f'{title}-forward') 84 | plt.legend() 85 | plt.savefig(f'plots/conv2d/{title}/forward.jpg') 86 | plt.close() 87 | 88 | for mn in module_names: 89 | plt.plot(sparsities, backward_times[mn], '-o', label=mn) 90 | plt.grid() 91 | plt.xlabel('sparsity') 92 | plt.ylabel('time') 93 | plt.ylim(bottom=0) 94 | plt.title(f'{title}-backward') 95 | plt.legend() 96 | plt.savefig(f'plots/conv2d/{title}/backward.jpg') 97 | plt.close() -------------------------------------------------------------------------------- /examples/compare_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from time import time 4 | from tqdm import tqdm 5 | from sparseprop.modules import SparseLinear, attach_identity_and_time 6 | import os, sys 7 | 8 | class SparseLinearTorch(torch.nn.Module): 9 | def __init__(self, W): 10 | super(SparseLinearTorch, self).__init__() 11 | self.W = W.clone() 12 | self.W = W.to_sparse() 13 | self.W.requires_grad_() 14 | self.W.retain_grad() 15 | 16 | def forward(self, X): 17 | return torch.sparse.mm(self.W, X.T).T 18 | 19 | def DenseLinear(W): 20 | N, M = W.shape 21 | linear = torch.nn.Linear(M, N, bias=False) 22 | with torch.no_grad(): 23 | linear.weight.mul_(0.) 24 | linear.weight.add_(W) 25 | return linear 26 | 27 | if __name__ == '__main__': 28 | M = 512 29 | N = 256 30 | B = 128 31 | sparsities = [.8, .9, .95, .98, .99] 32 | reps = 3 33 | 34 | torch.manual_seed(11) 35 | tag = sys.argv[1] if len(sys.argv) > 1 else None 36 | 37 | module_fns = [DenseLinear, SparseLinear, SparseLinearTorch] 38 | module_names = [m.__name__ for m in module_fns] 39 | 40 | forward_times = {m: [] for m in module_names} 41 | backward_times = {m: [] for m in module_names} 42 | 43 | for sparsity in sparsities: 44 | sp_forward_times = {m: [] for m in module_names} 45 | sp_backward_times = {m: [] for m in module_names} 46 | for _ in tqdm(range(reps)): 47 | W = torch.randn(N, M) 48 | mask = torch.rand_like(W) > sparsity 49 | W *= mask 50 | 51 | Y_orig = torch.randn(B, N) 52 | 53 | X_orig = torch.randn(B, M) 54 | X_orig.requires_grad_() 55 | X_orig.retain_grad() 56 | 57 | for module_name, module_fn in zip(module_names, module_fns): 58 | 59 | module = module_fn(W) 60 | X = X_orig.clone() 61 | Y = Y_orig.clone() 62 | 63 | bt, ft = attach_identity_and_time(module, X, Y, time_forward=True, time_backward=True) 64 | sp_forward_times[module_name].append(ft) 65 | sp_backward_times[module_name].append(bt) 66 | 67 | for mn in module_names: 68 | forward_times[mn].append(sum(sp_forward_times[mn]) / reps) 69 | backward_times[mn].append(sum(sp_backward_times[mn]) / reps) 70 | 71 | title = f'B{B}-M{M}-N{N}' 72 | if tag is not None: 73 | title += '-' + tag 74 | os.makedirs(f'plots/linear/{title}', exist_ok=False) 75 | 76 | for mn in module_names: 77 | plt.plot(sparsities, forward_times[mn], '-o', label=mn) 78 | plt.grid() 79 | plt.xlabel('sparsity') 80 | plt.ylabel('time') 81 | plt.ylim(bottom=0) 82 | plt.title(f'{title}-forward') 83 | plt.legend() 84 | plt.savefig(f'plots/linear/{title}/forward.jpg') 85 | plt.close() 86 | 87 | for mn in module_names: 88 | plt.plot(sparsities, backward_times[mn], '-o', label=mn) 89 | plt.grid() 90 | plt.xlabel('sparsity') 91 | plt.ylabel('time') 92 | plt.ylim(bottom=0) 93 | plt.title(f'{title}-backward') 94 | plt.legend() 95 | plt.savefig(f'plots/linear/{title}/backward.jpg') 96 | plt.close() -------------------------------------------------------------------------------- 
/examples/correctness_conv2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sparseprop.modules import SparseConv2d 3 | from sparseprop.utils import error 4 | import math 5 | from copy import deepcopy 6 | 7 | if __name__ == '__main__': 8 | torch.manual_seed(11) 9 | 10 | B = 256 # batch size 11 | IC = 512 # input channels 12 | OC = 512 # output channels 13 | M = 7 # input height 14 | N = 7 # input width 15 | K = 3 # kernel size 16 | stride = 1 # stride 17 | padding = 0 # padding 18 | vectorizing_over_on = False # as described in the paper 19 | sparsity = .9 # sparsity of the weights 20 | 21 | OM = math.ceil((M + 2 * padding - K + 1) / stride) 22 | ON = math.ceil((N + 2 * padding - K + 1) / stride) 23 | 24 | W = torch.randn(OC, IC, K, K) 25 | bias = torch.randn(OC) 26 | mask = torch.rand_like(W) > sparsity 27 | W *= mask 28 | 29 | Y_orig = torch.randn(B, OC, OM, ON) 30 | 31 | X_orig = torch.randn(B, IC, M, N) 32 | X_orig.requires_grad_() 33 | X_orig.retain_grad() 34 | 35 | torch_X = X_orig.clone() 36 | torch_X.retain_grad() 37 | torch_Y = Y_orig.clone() 38 | conv = torch.nn.Conv2d(IC, OC, K, stride=stride, padding=padding, bias=True) 39 | with torch.no_grad(): 40 | conv.weight.mul_(0.) 41 | conv.weight.add_(W) 42 | conv.bias.mul_(0.) 43 | conv.bias.add_(bias) 44 | torch_O = conv(torch_X) 45 | torch.mean((torch_O - torch_Y) ** 2).backward() 46 | torch_X_grad = torch_X.grad 47 | torch_W_grad = conv.weight.grad[conv.weight != 0] 48 | 49 | our_X = X_orig.clone() 50 | our_X.retain_grad() 51 | our_Y = Y_orig.clone() 52 | spconv = SparseConv2d(W, bias=torch.nn.Parameter(deepcopy(bias)), padding=padding, stride=stride, vectorizing_over_on=vectorizing_over_on) 53 | our_O = spconv(our_X) 54 | torch.mean((our_O - our_Y) ** 2).backward() 55 | our_X_grad = our_X.grad 56 | our_W_grad = spconv.W_val.grad 57 | 58 | print('[Forward]\n O error:', error(our_O, torch_O)) 59 | print('[Backward]\n X grad error:', error(our_X_grad, torch_X_grad), '\n W grad error:', error(our_W_grad, torch_W_grad)) 60 | print('[Backward]\n bias grad error:', error(spconv.bias.grad, conv.bias.grad)) -------------------------------------------------------------------------------- /examples/correctness_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sparseprop.modules import SparseLinear 3 | from sparseprop.utils import error 4 | import math 5 | from copy import deepcopy 6 | 7 | if __name__ == '__main__': 8 | torch.manual_seed(11) 9 | 10 | B = 128 # batch size 11 | M = 512 # input height 12 | N = 256 # input width 13 | sparsity = .9 14 | 15 | W = torch.randn(N, M) 16 | bias = torch.randn(N) 17 | mask = torch.rand_like(W) > sparsity 18 | W *= mask 19 | 20 | Y_orig = torch.randn(B, N) 21 | 22 | X_orig = torch.randn(B, M) 23 | X_orig.requires_grad_() 24 | X_orig.retain_grad() 25 | 26 | torch_X = X_orig.clone() 27 | torch_X.retain_grad() 28 | torch_Y = Y_orig.clone() 29 | linear = torch.nn.Linear(M, N, bias=True) 30 | 31 | print(linear.weight.shape, W.shape) 32 | with torch.no_grad(): 33 | linear.weight.mul_(0.) 34 | linear.weight.add_(W) 35 | linear.bias.mul_(0.) 
36 | linear.bias.add_(bias) 37 | 38 | torch_O = linear(torch_X) 39 | torch.mean((torch_O - torch_Y) ** 2).backward() 40 | torch_X_grad = torch_X.grad 41 | torch_W_grad = linear.weight.grad[linear.weight != 0] 42 | 43 | our_X = X_orig.clone() 44 | our_X.retain_grad() 45 | our_Y = Y_orig.clone() 46 | splinear = SparseLinear(W, bias=torch.nn.Parameter(deepcopy(bias))) 47 | our_O = splinear(our_X) 48 | torch.mean((our_O - our_Y) ** 2).backward() 49 | our_X_grad = our_X.grad 50 | our_W_grad = splinear.W_val.grad 51 | 52 | print('[Forward]\n O error:', error(our_O, torch_O)) 53 | print('[Backward]\n X grad error:', error(our_X_grad, torch_X_grad), '\n W grad error:', error(our_W_grad, torch_W_grad)) 54 | print('[Backward]\n bias grad error:', error(splinear.bias.grad, linear.bias.grad)) -------------------------------------------------------------------------------- /examples/finetune_resnet18_imagenette.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision.models import resnet18 4 | from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize 5 | import numpy 6 | import random 7 | 8 | from argparse import ArgumentParser 9 | from pprint import pformat 10 | import os 11 | 12 | from sparseprop.utils import sparsity, swap_modules_with_sparse 13 | from utils import Logger, Timer, apply_to_all_modules_with_types, Finetuner 14 | 15 | # arguments 16 | parser = ArgumentParser() 17 | parser.add_argument('-b','--batch-size', help='batch size for fine-tuning.', type=int, default=64) 18 | parser.add_argument('-nw','--num-workers', help='number of workers for dataloaders.', type=int, default=4) 19 | parser.add_argument('-rd', '--run-dense', help='set true to not use sparseprop and run everything dense.', action='store_true', default=False) 20 | parser.add_argument('-s','--seed', help='manual seed.', type=int, default=10) 21 | parser.add_argument('-e','--epochs', help='the number of epoch to train.', type=int, default=1) 22 | parser.add_argument('-sf','--save-frequency', help='how often to save the model (in epochs).', type=int, default=1) 23 | parser.add_argument('-lf','--log-frequency', help='how often to log (in batches).', type=int, default=20) 24 | parser.add_argument('-cp','--checkpoint-path', dest='ckpt_path', help='path to the pretrained sparse resnet18 checkpoint to be fine-tuned.', type=str, required=True) 25 | parser.add_argument('-od','--output-dir', dest='outdir', help='where to write the results. cannot already exist.', type=str, required=True) 26 | parser.add_argument('-dd','--dataset-dir', help='where to store the dataset. we recommend /dev/shm/datasets/imagenette/. 
storing the data in /dev/shm/ will map it directly to memory, minimizing the data loading overhead.', type=str, default='/dev/shm/datasets/imagenette/') 27 | args = parser.parse_args() 28 | 29 | 30 | # set the seed everywhere 31 | random.seed(args.seed) 32 | numpy.random.seed(args.seed) 33 | torch.manual_seed(args.seed) 34 | 35 | 36 | # make a directory to save the results in 37 | os.makedirs(args.outdir, exist_ok=False) 38 | 39 | 40 | # initialize the logger 41 | logger = Logger(args.outdir) 42 | 43 | # load the sparse model 44 | model = resnet18() 45 | ckpt = torch.load(args.ckpt_path, map_location='cpu') 46 | model.load_state_dict(ckpt) 47 | 48 | 49 | # print sparsity of each layer 50 | logger.log("Sparsity per layer:") 51 | logger.log(pformat(apply_to_all_modules_with_types( 52 | model, 53 | [torch.nn.Linear, torch.nn.Conv2d], 54 | lambda n, m: f'{sparsity(m):.3f}') 55 | , indent=4)) 56 | logger.log('-' * 40) 57 | 58 | 59 | # swap the last layer of the model to match the number of classes in imagenette 60 | fc = model.fc 61 | model.fc = torch.nn.Linear( 62 | fc.in_features, 63 | 10, # number of classes in imagenette 64 | bias=fc.bias is not None 65 | ) 66 | 67 | 68 | # load the datasets and prepare the dataloaders 69 | train_dataset, test_dataset = [ImagenetteDataset( 70 | root=args.dataset_dir, 71 | train=train, 72 | dataset_size=ImagenetteSize.s320, 73 | image_size=224 74 | ) for train in [True, False]] 75 | 76 | train_loader, test_loader = [DataLoader( 77 | dataset, 78 | batch_size=args.batch_size, 79 | num_workers=args.num_workers, 80 | shuffle=True 81 | ) for dataset in [train_dataset, test_dataset]] 82 | 83 | logger.log(f'Total number of training batches: {len(train_loader)}') 84 | 85 | 86 | # replace modules with sparse ones if --run-dense not requested 87 | if not args.run_dense: 88 | input_shape = next(iter(train_loader))[0].shape 89 | model = swap_modules_with_sparse(model, input_shape, inplace=True, verbose=True) 90 | 91 | 92 | # loss and optim 93 | loss_fn = torch.nn.CrossEntropyLoss() 94 | optimizer = torch.optim.SGD(model.parameters(), lr=5e-3, momentum=0.9, weight_decay=1e-4) 95 | 96 | 97 | # initialize the finetuner 98 | finetuner = Finetuner( 99 | model, 100 | optimizer, 101 | schedular=None, 102 | loss_fn=loss_fn, 103 | log_freq=args.log_frequency, 104 | save_freq=args.save_frequency, 105 | logger=logger 106 | ) 107 | 108 | 109 | # finetune 110 | finetuner.finetune(train_loader, test_loader, args.epochs) 111 | 112 | 113 | -------------------------------------------------------------------------------- /examples/models/resnet18_ac_dc_500_epochs_sp=0.95_uniform.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IST-DASLab/sparseprop/df97d8e15b372b59103dceb0df71e1013bbf12a7/examples/models/resnet18_ac_dc_500_epochs_sp=0.95_uniform.pt -------------------------------------------------------------------------------- /examples/notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SparseProp Usage Guide\n", 8 | "\n", 9 | "This notebook serves as a guide on how to effectively utilize SparseProp. 
You'll find detailed steps on how to take advantage of SparseProp for both individual layers, as well as the entire network to accelerate the backpropagation process.\n", 10 | "\n", 11 | "As an introduction, SparseProp provides a low-level CPU implementation of backpropagation, where the weights of a layer are unstructured sparse. More specifically, if we have a sparse fully connected or convolution layer, SparseProp is capable of speeding up the backpropagation process on CPU. We further integrate SparseProp with the PyTorch framework, providing the *SparseLinear* and *SparseConv2d* modules as drop-in replacements for PyTorch's *Linear* and *Conv2d* modules, respectively. Further details of our algorithms can be found in [our paper](https://arxiv.org/abs/2302.04852).\n", 12 | "\n", 13 | "If you haven't already installed *SparseProp*, make sure you have PyTorch installed, and then simply run the following cell:" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "%pip install sparseprop" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "Now, let's get started! Here we only consider the case where only a single thread is being used for doing the computations. Run the following cell to limit both *PyTorch* and *SparseProp* to a single thread." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import torch\n", 39 | "import sparseprop\n", 40 | "\n", 41 | "torch.set_num_threads(1)\n", 42 | "sparseprop.set_num_threads(1)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "Also, let's set the random seeds to get consistent results." 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "'All seeds were set to 10.'" 61 | ] 62 | }, 63 | "execution_count": 2, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "import random\n", 70 | "import numpy as np\n", 71 | "\n", 72 | "seed = 10\n", 73 | "random.seed(seed)\n", 74 | "np.random.seed(seed)\n", 75 | "torch.manual_seed(seed)\n", 76 | "\n", 77 | "f\"All seeds were set to {seed}.\"" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Individual Layer\n", 85 | "\n", 86 | "Let's say we have a *Linear* module, which is 98% sparse. For the sake of argument, let's actually create such module. We assume that the input and output dimensions are 768 and 3072, respectively, but any other dimensions work just as fine." 
87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 3, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "data": { 96 | "text/plain": [ 97 | "\"Our module's spasity is now 0.98.\"" 98 | ] 99 | }, 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "from torch.nn import Linear\n", 107 | "\n", 108 | "linear = Linear(768, 3072) # input size of 768, and output size of 3072\n", 109 | "\n", 110 | "# prune the module randomly to 98% unstructred sparsity\n", 111 | "with torch.no_grad():\n", 112 | "\n", 113 | " # generate a random mask with roughly 98% sparsity\n", 114 | " mask = torch.rand_like(linear.weight) > 0.98\n", 115 | "\n", 116 | " # apply the mask to the module\n", 117 | " linear.weight.mul_(mask.float())\n", 118 | "\n", 119 | "f\"Our module's spasity is now {(linear.weight == 0).float().mean().item():.2f}.\"" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "So now we actually have a 98% sparse module, called `linear`. Let's see how long forward and backward steps take on this module. Assuming the batch size is 2048, we generate a synthetic batch of data." 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 4, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "X = torch.randn(2048, 768) # batch_size x input_dimension\n", 136 | "\n", 137 | "# the following two lines tell PyTorch to keep the gradients for the input tensor\n", 138 | "X.requires_grad_()\n", 139 | "X.retain_grad()\n", 140 | "\n", 141 | "y = torch.randn(2048, 3072) # batch_size x output_dimension" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": {}, 147 | "source": [ 148 | "Now we measure the time required for the forward and backward steps of the `linear` module." 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 5, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "text/plain": [ 159 | "'The forward pass took 0.083 seconds.'" 160 | ] 161 | }, 162 | "execution_count": 5, 163 | "metadata": {}, 164 | "output_type": "execute_result" 165 | } 166 | ], 167 | "source": [ 168 | "import time\n", 169 | "\n", 170 | "# time the forward step\n", 171 | "start = time.time()\n", 172 | "O = linear(X)\n", 173 | "pytorch_forward_time = time.time() - start\n", 174 | "f\"The forward pass took {pytorch_forward_time:.3f} seconds.\"" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 6, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "'The backward pass took 0.196 seconds.'" 186 | ] 187 | }, 188 | "execution_count": 6, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "# calculate the mse loss\n", 195 | "L = torch.mean((y - O) ** 2)\n", 196 | "\n", 197 | "# time the backward step\n", 198 | "start = time.time()\n", 199 | "L.backward()\n", 200 | "pytorch_backward_time = time.time() - start\n", 201 | "f\"The backward pass took {pytorch_backward_time:.3f} seconds.\"" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "Notice we haven't exploited *SparseProp*'s implementations yet. 
Let's see how much speedup we can get if we utilize SparseProp.\n", 209 | "\n", 210 | "To do so, we only need one line of code:" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 7, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "name": "stdout", 220 | "output_type": "stream", 221 | "text": [ 222 | "SparseLinear([3072, 768], sp=0.98, nnz=46972)\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "from sparseprop.modules import SparseLinear\n", 228 | "\n", 229 | "# this line will convert your pytorch module to a sparseprop module\n", 230 | "sparse_linear = SparseLinear.from_dense(linear)\n", 231 | "\n", 232 | "print(sparse_linear)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "Now that we have a *SparseProp* module, let's again compute the forward and backward times:" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 8, 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "data": { 249 | "text/plain": [ 250 | "'The forward pass took 0.038 seconds.'" 251 | ] 252 | }, 253 | "execution_count": 8, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "# time the forward step\n", 260 | "start = time.time()\n", 261 | "O = sparse_linear(X)\n", 262 | "sparseprop_forward_time = time.time() - start\n", 263 | "f\"The forward pass took {sparseprop_forward_time:.3f} seconds.\"" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 9, 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "data": { 273 | "text/plain": [ 274 | "'The backward pass took 0.084 seconds.'" 275 | ] 276 | }, 277 | "execution_count": 9, 278 | "metadata": {}, 279 | "output_type": "execute_result" 280 | } 281 | ], 282 | "source": [ 283 | "# calculate the mse loss\n", 284 | "L = torch.mean((y - O) ** 2)\n", 285 | "\n", 286 | "# time the backward step\n", 287 | "start = time.time()\n", 288 | "L.backward()\n", 289 | "sparseprop_backward_time = time.time() - start\n", 290 | "f\"The backward pass took {sparseprop_backward_time:.3f} seconds.\"" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": {}, 296 | "source": [ 297 | "The numbers you get will highly depend on your CPU architecture, but you should generally be able to see a non-trivial speedup with *SparseProp* with respect to PyTorch's implementations. Run the following cell to compare the two:" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 10, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "Forward speedup: 2.22x\n", 310 | "Backward speedup: 2.35x\n" 311 | ] 312 | } 313 | ], 314 | "source": [ 315 | "print(f\"Forward speedup: {pytorch_forward_time / sparseprop_forward_time:.2f}x\")\n", 316 | "print(f\"Backward speedup: {pytorch_backward_time / sparseprop_backward_time:.2f}x\")" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "If you have a `Conv2d` module instead of a `Linear` one, you can again use *SparseProp* to gain speedups. The interface is exactly the same with only one differnce. 
If your module is called `conv`, you can do:\n", 324 | "\n", 325 | "```\n", 326 | "from sparseprop.modules import SparseConv2d\n", 327 | "\n", 328 | "sparse_conv = SparseConv2d.from_dense(conv, vectorizing_over_on=False)\n", 329 | "```\n", 330 | "\n", 331 | "The only difference with the *Linear* case is that there is an additional boolean argument `vectorizing_over_on`. As described in [the paper](https://arxiv.org/abs/2302.04852), we have two implementations for the convolution case, one performing the vectorization over the batch size, and the other over the output dimension. Using this argument you can specify which one of the two implementations to use. A quick rule of thumb is that if the input width and height are small (e.g., less than 32) then `vectorizing_over_on=False` is faster.\n", 332 | "\n", 333 | "Alternatively, the `sparsify_conv2d_auto` method can automatically determine the correct value of `vectorizing_over_on`.\n", 334 | "\n", 335 | "```\n", 336 | "from sparseprop.modules import sparsify_conv2d_auto\n", 337 | "\n", 338 | "sparse_conv = sparsify_conv2d_auto(conv, input_shape, verbose=True)\n", 339 | "```\n", 340 | "\n", 341 | "Notice that you will need to feed the `input_shape` to this method, which should look something like (`batch_size`, `input_channels`, `input_height`, `input_width`). This method will create two sparse modules, one with `vectorizing_over_on=False` and the other one with `vectorizing_over_on=True`, run a randomly generated batch through both, and return the faster module based on forward+backward time." 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "metadata": {}, 347 | "source": [ 348 | "## Full Network" 349 | ] 350 | }, 351 | { 352 | "cell_type": "markdown", 353 | "metadata": {}, 354 | "source": [ 355 | "Now assume you have a sparse network instead of just one layer. *SparseProp* offers tools that can seamlessly process your network object, and replace its layers with their corresponding sparse counterparts with only one or two lines of extra code. Let's go through an example of this case.\n", 356 | "\n", 357 | "Consider the scenario where you have a sparse model pre-trained on a large dataset (e.g., ImageNet). Let's say you want to fine-tune this sparse model on a smaller dataset (e.g., ImageNette), while keeping the sparsity mask fixed. This process is called *sparse transfer learning*.\n", 358 | "\n", 359 | "We have provided the checkpoint for a 95% uniform sparse ResNet18 model pre-trained on ImageNet at `models/resnet18_ac_dc_500_epochs_sp=0.95_uniform.pt`. Let's go ahead and load this model!" 360 | ] 361 | }, 362 | { 363 | "attachments": {}, 364 | "cell_type": "markdown", 365 | "metadata": {}, 366 | "source": [ 367 | "Let's again fix the random seed so we get consistent results." 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 10, 373 | "metadata": {}, 374 | "outputs": [ 375 | { 376 | "data": { 377 | "text/plain": [ 378 | "'All seeds were set to 11.'" 379 | ] 380 | }, 381 | "execution_count": 10, 382 | "metadata": {}, 383 | "output_type": "execute_result" 384 | } 385 | ], 386 | "source": [ 387 | "# set the seed everywhere\n", 388 | "seed = 11\n", 389 | "random.seed(seed)\n", 390 | "np.random.seed(seed)\n", 391 | "torch.manual_seed(seed)\n", 392 | "\n", 393 | "f\"All seeds were set to {seed}.\"" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "Now let's create a directory to store the log file and initialize a `Logger`. 
" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 11, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "import os\n", 410 | "from utils import Logger\n", 411 | "\n", 412 | "outdir = \"results-finetune-resnet18-imagenette/\"\n", 413 | "os.makedirs(outdir, exist_ok=False)\n", 414 | "logger = Logger(outdir)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": {}, 420 | "source": [ 421 | "It's time to load the checkpoint:" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "from torchvision.models import resnet18\n", 431 | "\n", 432 | "model = resnet18() # initialize the network\n", 433 | "ckpt = torch.load('models/resnet18_ac_dc_500_epochs_sp=0.95_uniform.pt', map_location='cpu') # read the checkpoint from the file\n", 434 | "model.load_state_dict(ckpt) # load the checkpoint into the network" 435 | ] 436 | }, 437 | { 438 | "cell_type": "markdown", 439 | "metadata": {}, 440 | "source": [ 441 | "Run the following block of code to print the sparsity level of each layer. Since the network is 95% uniformly sparse, we expect all the layers (except the first and last ones) to have a sparsity of exactly 95%.\n", 442 | "\n", 443 | "You will notice we have used the function `apply_to_all_modules_with_types(model, types, fn)`. This function iterates through the layers of `model`, and if the type of the layer is in the `types` list, it will apply the `fn` function on it and return the results. " 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 13, 449 | "metadata": {}, 450 | "outputs": [ 451 | { 452 | "name": "stdout", 453 | "output_type": "stream", 454 | "text": [ 455 | "Sparsity per layer:\n", 456 | "OrderedDict([ ('conv1', '0.000'),\n", 457 | " ('layer1.0.conv1', '0.950'),\n", 458 | " ('layer1.0.conv2', '0.950'),\n", 459 | " ('layer1.1.conv1', '0.950'),\n", 460 | " ('layer1.1.conv2', '0.950'),\n", 461 | " ('layer2.0.conv1', '0.950'),\n", 462 | " ('layer2.0.conv2', '0.950'),\n", 463 | " ('layer2.0.downsample.0', '0.950'),\n", 464 | " ('layer2.1.conv1', '0.950'),\n", 465 | " ('layer2.1.conv2', '0.950'),\n", 466 | " ('layer3.0.conv1', '0.950'),\n", 467 | " ('layer3.0.conv2', '0.950'),\n", 468 | " ('layer3.0.downsample.0', '0.950'),\n", 469 | " ('layer3.1.conv1', '0.950'),\n", 470 | " ('layer3.1.conv2', '0.950'),\n", 471 | " ('layer4.0.conv1', '0.950'),\n", 472 | " ('layer4.0.conv2', '0.950'),\n", 473 | " ('layer4.0.downsample.0', '0.950'),\n", 474 | " ('layer4.1.conv1', '0.950'),\n", 475 | " ('layer4.1.conv2', '0.950'),\n", 476 | " ('fc', '0.000')])\n" 477 | ] 478 | } 479 | ], 480 | "source": [ 481 | "from pprint import pformat # just to print a dictionary nicely\n", 482 | "from utils import apply_to_all_modules_with_types\n", 483 | "from sparseprop.utils import sparsity\n", 484 | "\n", 485 | "logger.log(\"Sparsity per layer:\")\n", 486 | "logger.log(pformat(apply_to_all_modules_with_types(\n", 487 | " model,\n", 488 | " [torch.nn.Linear, torch.nn.Conv2d], # we only want the sparsity of linear and conv2d modules\n", 489 | " lambda name, module: f'{sparsity(module):.3f}') # calculate the sparsity for each module\n", 490 | ", indent=4))" 491 | ] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "metadata": {}, 496 | "source": [ 497 | "This model is pre-trained on the ImageNet dataset, which consists of 1000 classes. However, for fine-tuning, we will be using the ImageNette dataset, which only has 10 classes. 
As a result, we will need to replace the classifier layer in order to adapt the model to this specific task." 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 14, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "model.fc = torch.nn.Linear(\n", 507 | " model.fc.in_features, # number of input features\n", 508 | " 10, # number of classes in imagenette\n", 509 | " bias=model.fc.bias is not None # keep the bias if exists\n", 510 | ")" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "Now let's get our dataset and dataloaders ready. We directly load the ImageNette dataset from the *SparseML* library. You can run the following commnad to install it on your environment." 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": {}, 524 | "outputs": [], 525 | "source": [ 526 | "%pip install sparseml" 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": {}, 532 | "source": [ 533 | "Now that the library installed, we can load the dataset." 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": null, 539 | "metadata": {}, 540 | "outputs": [], 541 | "source": [ 542 | "from torch.utils.data import DataLoader\n", 543 | "from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize\n", 544 | "\n", 545 | "# load the datasets\n", 546 | "train_dataset, test_dataset = [ImagenetteDataset(\n", 547 | " root='/dev/shm/', # store the dataset in /dev/shm/ to map the dataset to memory and avoid data loading overheads\n", 548 | " train=train,\n", 549 | " dataset_size=ImagenetteSize.s320,\n", 550 | " image_size=224\n", 551 | ") for train in [True, False]]\n", 552 | "\n", 553 | "# prepare the dataloaders\n", 554 | "train_loader, test_loader = [DataLoader(\n", 555 | " dataset,\n", 556 | " batch_size=256,\n", 557 | " num_workers=4,\n", 558 | " shuffle=True\n", 559 | ") for dataset in [train_dataset, test_dataset]]" 560 | ] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "metadata": {}, 565 | "source": [ 566 | "Now here's where *SparseProp* comes into play. As explained in the paper, we replace each Linear or Conv2d layer in a network with a sparse one, if the following conditions are met:\n", 567 | "\n", 568 | "- It's at least 80% sparse.\n", 569 | "- The sparse module is faster than the original dense one (in terms of forward+backward time).\n", 570 | "\n", 571 | "This behavior is implemented in the `swap_modules_with_sparse` method in `sparseprop.utils`. Let's do this!" 
572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": 17, 577 | "metadata": {}, 578 | "outputs": [ 579 | { 580 | "name": "stdout", 581 | "output_type": "stream", 582 | "text": [ 583 | "------------------------------\n", 584 | "keeping the module conv1 dense...\n", 585 | "------------------------------\n", 586 | "module Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6537895202636719 fwd and 1.5996980667114258 bwd\n", 587 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=False) took 1.3576796054840088 fwd and 1.9765965938568115 bwd\n", 588 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) took 0.24248600006103516 fwd and 0.5414636135101318 bwd\n", 589 | "going with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) with full time of 0.783949613571167\n", 590 | "module layer1.0.conv1 replaced with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 591 | "------------------------------\n", 592 | "module Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6515367031097412 fwd and 1.5962450504302979 bwd\n", 593 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=False) took 1.3252687454223633 fwd and 1.9624810218811035 bwd\n", 594 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) took 0.25075316429138184 fwd and 0.5489640235900879 bwd\n", 595 | "going with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) with full time of 0.7997171878814697\n", 596 | "module layer1.0.conv2 replaced with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 597 | "------------------------------\n", 598 | "module Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.651078462600708 fwd and 1.5953965187072754 bwd\n", 599 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=False) took 1.3374743461608887 fwd and 1.9667127132415771 bwd\n", 600 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) took 0.24969148635864258 fwd and 0.5471739768981934 bwd\n", 601 | "going with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) with full time of 0.7968654632568359\n", 602 | "module layer1.1.conv1 replaced with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 603 | "------------------------------\n", 604 | "module Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6508922576904297 fwd and 1.5951457023620605 bwd\n", 605 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=False) took 1.331139087677002 fwd and 1.9636893272399902 bwd\n", 606 | "module SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) took 0.25159716606140137 fwd and 0.5491302013397217 bwd\n", 607 | "going with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True) with full time of 0.800727367401123\n", 608 | "module layer1.1.conv2 replaced with SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 609 | "------------------------------\n", 610 | "module Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) took 0.36971068382263184 fwd and 0.9490935802459717 bwd\n", 611 | "module SparseConv2d([128, 64, 3, 3], sp=0.95, nnz=3686, s=2, p=1, voo=False) took 1.1540188789367676 fwd and 0.9293184280395508 bwd\n", 612 | "module SparseConv2d([128, 64, 3, 3], sp=0.95, 
nnz=3686, s=2, p=1, voo=True) took 1.1658782958984375 fwd and 0.6493189334869385 bwd\n", 613 | "going with Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) with full time of 1.3188042640686035\n", 614 | "keeping the module layer2.0.conv1 dense...\n", 615 | "------------------------------\n", 616 | "module Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6316447257995605 fwd and 1.5091893672943115 bwd\n", 617 | "module SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=False) took 0.6453986167907715 fwd and 0.9138906002044678 bwd\n", 618 | "module SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True) took 0.3703117370605469 fwd and 0.687126636505127 bwd\n", 619 | "going with SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True) with full time of 1.0574383735656738\n", 620 | "module layer2.0.conv2 replaced with SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True)\n", 621 | "------------------------------\n", 622 | "module Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) took 0.0976557731628418 fwd and 0.278552770614624 bwd\n", 623 | "module SparseConv2d([128, 64, 1, 1], sp=0.95, nnz=410, s=2, p=0, voo=False) took 0.9179229736328125 fwd and 0.7433874607086182 bwd\n", 624 | "module SparseConv2d([128, 64, 1, 1], sp=0.95, nnz=410, s=2, p=0, voo=True) took 0.9330620765686035 fwd and 0.2980782985687256 bwd\n", 625 | "going with Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) with full time of 0.3762085437774658\n", 626 | "keeping the module layer2.0.downsample.0 dense...\n", 627 | "------------------------------\n", 628 | "module Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6319241523742676 fwd and 1.508660078048706 bwd\n", 629 | "module SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=False) took 0.6373977661132812 fwd and 0.9119634628295898 bwd\n", 630 | "module SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True) took 0.37430548667907715 fwd and 0.6898000240325928 bwd\n", 631 | "going with SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True) with full time of 1.06410551071167\n", 632 | "module layer2.1.conv1 replaced with SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True)\n", 633 | "------------------------------\n", 634 | "module Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6316280364990234 fwd and 1.5089399814605713 bwd\n", 635 | "module SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=False) took 0.6428439617156982 fwd and 0.9137217998504639 bwd\n", 636 | "module SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True) took 0.3743324279785156 fwd and 0.6895146369934082 bwd\n", 637 | "going with SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True) with full time of 1.0638470649719238\n", 638 | "module layer2.1.conv2 replaced with SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True)\n", 639 | "------------------------------\n", 640 | "module Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) took 0.3503093719482422 fwd and 0.8493261337280273 bwd\n", 641 | "module SparseConv2d([256, 128, 3, 3], sp=0.95, nnz=14746, s=2, p=1, voo=False) took 0.4656362533569336 fwd and 0.5568394660949707 bwd\n", 642 | "module SparseConv2d([256, 128, 3, 3], sp=0.95, nnz=14746, s=2, p=1, voo=True) took 0.46555352210998535 
fwd and 0.8326826095581055 bwd\n", 643 | "going with SparseConv2d([256, 128, 3, 3], sp=0.95, nnz=14746, s=2, p=1, voo=False) with full time of 1.0224757194519043\n", 644 | "module layer3.0.conv1 replaced with SparseConv2d([256, 128, 3, 3], sp=0.95, nnz=14746, s=2, p=1, voo=False)\n", 645 | "------------------------------\n", 646 | "module Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6343159675598145 fwd and 1.3283319473266602 bwd\n", 647 | "module SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False) took 0.5050003528594971 fwd and 0.6285405158996582 bwd\n", 648 | "module SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=True) took 0.7813265323638916 fwd and 1.2833356857299805 bwd\n", 649 | "going with SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False) with full time of 1.1335408687591553\n", 650 | "module layer3.0.conv2 replaced with SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False)\n", 651 | "------------------------------\n", 652 | "module Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) took 0.06782793998718262 fwd and 0.1881122589111328 bwd\n", 653 | "module SparseConv2d([256, 128, 1, 1], sp=0.95, nnz=1638, s=2, p=0, voo=False) took 0.22548270225524902 fwd and 0.385814905166626 bwd\n", 654 | "module SparseConv2d([256, 128, 1, 1], sp=0.95, nnz=1638, s=2, p=0, voo=True) took 0.22622132301330566 fwd and 0.2331991195678711 bwd\n", 655 | "going with Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) with full time of 0.25594019889831543\n", 656 | "keeping the module layer3.0.downsample.0 dense...\n", 657 | "------------------------------\n", 658 | "module Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6341907978057861 fwd and 1.3260297775268555 bwd\n", 659 | "module SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False) took 0.49993395805358887 fwd and 0.6333653926849365 bwd\n", 660 | "module SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=True) took 0.7733349800109863 fwd and 1.2643136978149414 bwd\n", 661 | "going with SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False) with full time of 1.1332993507385254\n", 662 | "module layer3.1.conv1 replaced with SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False)\n", 663 | "------------------------------\n", 664 | "module Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6344144344329834 fwd and 1.326756238937378 bwd\n", 665 | "module SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False) took 0.5124797821044922 fwd and 0.6291441917419434 bwd\n", 666 | "module SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=True) took 0.7771730422973633 fwd and 1.2691457271575928 bwd\n", 667 | "going with SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False) with full time of 1.1416239738464355\n", 668 | "module layer3.1.conv2 replaced with SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False)\n", 669 | "------------------------------\n", 670 | "module Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) took 0.40129590034484863 fwd and 0.8475296497344971 bwd\n", 671 | "module SparseConv2d([512, 256, 3, 3], sp=0.95, nnz=58982, s=2, p=1, voo=False) took 0.3875877857208252 fwd and 0.31590890884399414 bwd\n", 672 | "module SparseConv2d([512, 256, 3, 3], sp=0.95, nnz=58982, s=2, 
p=1, voo=True) took 0.38722801208496094 fwd and 1.4996874332427979 bwd\n", 673 | "going with SparseConv2d([512, 256, 3, 3], sp=0.95, nnz=58982, s=2, p=1, voo=False) with full time of 0.7034966945648193\n", 674 | "module layer4.0.conv1 replaced with SparseConv2d([512, 256, 3, 3], sp=0.95, nnz=58982, s=2, p=1, voo=False)\n", 675 | "------------------------------\n", 676 | "module Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6806645393371582 fwd and 1.3471698760986328 bwd\n", 677 | "module SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False) took 0.35182809829711914 fwd and 0.4279053211212158 bwd\n", 678 | "module SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=True) took 1.8099141120910645 fwd and 2.827181339263916 bwd\n", 679 | "going with SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False) with full time of 0.779733419418335\n", 680 | "module layer4.0.conv2 replaced with SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False)\n", 681 | "------------------------------\n", 682 | "module Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) took 0.04999351501464844 fwd and 0.14213848114013672 bwd\n", 683 | "module SparseConv2d([512, 256, 1, 1], sp=0.95, nnz=6554, s=2, p=0, voo=False) took 0.12142157554626465 fwd and 0.16385126113891602 bwd\n", 684 | "module SparseConv2d([512, 256, 1, 1], sp=0.95, nnz=6554, s=2, p=0, voo=True) took 0.1244204044342041 fwd and 0.3151395320892334 bwd\n", 685 | "going with Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) with full time of 0.19213199615478516\n", 686 | "keeping the module layer4.0.downsample.0 dense...\n", 687 | "------------------------------\n", 688 | "module Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6791129112243652 fwd and 1.341883659362793 bwd\n", 689 | "module SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False) took 0.334489107131958 fwd and 0.41767215728759766 bwd\n", 690 | "module SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=True) took 1.7890021800994873 fwd and 2.7480850219726562 bwd\n", 691 | "going with SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False) with full time of 0.7521612644195557\n", 692 | "module layer4.1.conv1 replaced with SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False)\n", 693 | "------------------------------\n", 694 | "module Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) took 0.6805310249328613 fwd and 1.3308117389678955 bwd\n", 695 | "module SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False) took 0.3323853015899658 fwd and 0.41096043586730957 bwd\n", 696 | "module SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=True) took 1.7962658405303955 fwd and 2.7770447731018066 bwd\n", 697 | "going with SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False) with full time of 0.7433457374572754\n", 698 | "module layer4.1.conv2 replaced with SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False)\n", 699 | "------------------------------\n", 700 | "keeping the module fc dense...\n" 701 | ] 702 | } 703 | ], 704 | "source": [ 705 | "from sparseprop.utils import swap_modules_with_sparse\n", 706 | "\n", 707 | "input_shape = next(iter(train_loader))[0].shape # we need the shape of our data\n", 708 | "\n", 709 | "# here's where the magic happens\n", 710 | 
"model = swap_modules_with_sparse(model, input_shape, inplace=True, verbose=True)" 711 | ] 712 | }, 713 | { 714 | "cell_type": "markdown", 715 | "metadata": {}, 716 | "source": [ 717 | "Let's see how our model looks like:" 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "execution_count": 18, 723 | "metadata": {}, 724 | "outputs": [ 725 | { 726 | "data": { 727 | "text/plain": [ 728 | "ResNet(\n", 729 | " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", 730 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 731 | " (relu): ReLU(inplace=True)\n", 732 | " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", 733 | " (layer1): Sequential(\n", 734 | " (0): BasicBlock(\n", 735 | " (conv1): SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 736 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 737 | " (relu): ReLU(inplace=True)\n", 738 | " (conv2): SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 739 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 740 | " )\n", 741 | " (1): BasicBlock(\n", 742 | " (conv1): SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 743 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 744 | " (relu): ReLU(inplace=True)\n", 745 | " (conv2): SparseConv2d([64, 64, 3, 3], sp=0.95, nnz=1843, s=1, p=1, voo=True)\n", 746 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 747 | " )\n", 748 | " )\n", 749 | " (layer2): Sequential(\n", 750 | " (0): BasicBlock(\n", 751 | " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 752 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 753 | " (relu): ReLU(inplace=True)\n", 754 | " (conv2): SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True)\n", 755 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 756 | " (downsample): Sequential(\n", 757 | " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 758 | " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 759 | " )\n", 760 | " )\n", 761 | " (1): BasicBlock(\n", 762 | " (conv1): SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True)\n", 763 | " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 764 | " (relu): ReLU(inplace=True)\n", 765 | " (conv2): SparseConv2d([128, 128, 3, 3], sp=0.95, nnz=7373, s=1, p=1, voo=True)\n", 766 | " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 767 | " )\n", 768 | " )\n", 769 | " (layer3): Sequential(\n", 770 | " (0): BasicBlock(\n", 771 | " (conv1): SparseConv2d([256, 128, 3, 3], sp=0.95, nnz=14746, s=2, p=1, voo=False)\n", 772 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 773 | " (relu): ReLU(inplace=True)\n", 774 | " (conv2): SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False)\n", 775 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 776 | " (downsample): Sequential(\n", 777 | " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 778 | " (1): 
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 779 | " )\n", 780 | " )\n", 781 | " (1): BasicBlock(\n", 782 | " (conv1): SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False)\n", 783 | " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 784 | " (relu): ReLU(inplace=True)\n", 785 | " (conv2): SparseConv2d([256, 256, 3, 3], sp=0.95, nnz=29491, s=1, p=1, voo=False)\n", 786 | " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 787 | " )\n", 788 | " )\n", 789 | " (layer4): Sequential(\n", 790 | " (0): BasicBlock(\n", 791 | " (conv1): SparseConv2d([512, 256, 3, 3], sp=0.95, nnz=58982, s=2, p=1, voo=False)\n", 792 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 793 | " (relu): ReLU(inplace=True)\n", 794 | " (conv2): SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False)\n", 795 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 796 | " (downsample): Sequential(\n", 797 | " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", 798 | " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 799 | " )\n", 800 | " )\n", 801 | " (1): BasicBlock(\n", 802 | " (conv1): SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False)\n", 803 | " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 804 | " (relu): ReLU(inplace=True)\n", 805 | " (conv2): SparseConv2d([512, 512, 3, 3], sp=0.95, nnz=117965, s=1, p=1, voo=False)\n", 806 | " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 807 | " )\n", 808 | " )\n", 809 | " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", 810 | " (fc): Linear(in_features=512, out_features=10, bias=True)\n", 811 | ")" 812 | ] 813 | }, 814 | "execution_count": 18, 815 | "metadata": {}, 816 | "output_type": "execute_result" 817 | } 818 | ], 819 | "source": [ 820 | "model" 821 | ] 822 | }, 823 | { 824 | "cell_type": "markdown", 825 | "metadata": {}, 826 | "source": [ 827 | "Now that we have our model and dataloaders ready, let's create our loss criterion and optimizer objects." 828 | ] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": 19, 833 | "metadata": {}, 834 | "outputs": [], 835 | "source": [ 836 | "# loss and optim\n", 837 | "loss_fn = torch.nn.CrossEntropyLoss()\n", 838 | "optimizer = torch.optim.SGD(model.parameters(), lr=5e-3, momentum=0.9, weight_decay=1e-4)" 839 | ] 840 | }, 841 | { 842 | "cell_type": "markdown", 843 | "metadata": {}, 844 | "source": [ 845 | "Finally, let's create our `Finetuner` object to perform one epoch (51 batches) of training. This object handles training and validaton, as well as performing extensive timings on the network." 
846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "execution_count": 20, 851 | "metadata": {}, 852 | "outputs": [ 853 | { 854 | "name": "stdout", 855 | "output_type": "stream", 856 | "text": [ 857 | "[Train] Epoch 1, Step 1: loss=2.7053, acc=0.0820\n", 858 | "Timings: avg_end_to_end_forward=11.1022, avg_end_to_end_backward=11.5243, avg_end_to_end_minibatch=22.6428, avg_module_forward_sum=6.9712, avg_module_backward_sum=7.5827\n", 859 | "[Train] Epoch 1, Step 2: loss=2.6545, acc=0.0898\n", 860 | "Timings: avg_end_to_end_forward=11.0292, avg_end_to_end_backward=11.3527, avg_end_to_end_minibatch=22.3985, avg_module_forward_sum=6.9279, avg_module_backward_sum=7.5711\n", 861 | "[Train] Epoch 1, Step 3: loss=2.5691, acc=0.0924\n", 862 | "Timings: avg_end_to_end_forward=11.0056, avg_end_to_end_backward=11.2854, avg_end_to_end_minibatch=22.3066, avg_module_forward_sum=6.9138, avg_module_backward_sum=7.5587\n", 863 | "[Train] Epoch 1, Step 4: loss=2.4842, acc=0.1113\n", 864 | "Timings: avg_end_to_end_forward=10.9948, avg_end_to_end_backward=11.2466, avg_end_to_end_minibatch=22.2568, avg_module_forward_sum=6.9054, avg_module_backward_sum=7.5530\n", 865 | "[Train] Epoch 1, Step 5: loss=2.3990, acc=0.1437\n", 866 | "Timings: avg_end_to_end_forward=10.9984, avg_end_to_end_backward=11.2293, avg_end_to_end_minibatch=22.2431, avg_module_forward_sum=6.9095, avg_module_backward_sum=7.5502\n", 867 | "[Train] Epoch 1, Step 6: loss=2.3070, acc=0.1953\n", 868 | "Timings: avg_end_to_end_forward=10.9899, avg_end_to_end_backward=11.2169, avg_end_to_end_minibatch=22.2220, avg_module_forward_sum=6.9045, avg_module_backward_sum=7.5499\n", 869 | "[Train] Epoch 1, Step 7: loss=2.2150, acc=0.2394\n", 870 | "Timings: avg_end_to_end_forward=10.9836, avg_end_to_end_backward=11.1999, avg_end_to_end_minibatch=22.1986, avg_module_forward_sum=6.9007, avg_module_backward_sum=7.5457\n", 871 | "[Train] Epoch 1, Step 8: loss=2.1194, acc=0.2910\n", 872 | "Timings: avg_end_to_end_forward=10.9775, avg_end_to_end_backward=11.1892, avg_end_to_end_minibatch=22.1819, avg_module_forward_sum=6.8955, avg_module_backward_sum=7.5437\n", 873 | "[Train] Epoch 1, Step 9: loss=2.0242, acc=0.3368\n", 874 | "Timings: avg_end_to_end_forward=10.9733, avg_end_to_end_backward=11.1817, avg_end_to_end_minibatch=22.1701, avg_module_forward_sum=6.8918, avg_module_backward_sum=7.5428\n", 875 | "[Train] Epoch 1, Step 10: loss=1.9279, acc=0.3820\n", 876 | "Timings: avg_end_to_end_forward=10.9689, avg_end_to_end_backward=11.1748, avg_end_to_end_minibatch=22.1587, avg_module_forward_sum=6.8890, avg_module_backward_sum=7.5417\n", 877 | "[Train] Epoch 1, Step 11: loss=1.8320, acc=0.4244\n", 878 | "Timings: avg_end_to_end_forward=10.9660, avg_end_to_end_backward=11.1696, avg_end_to_end_minibatch=22.1503, avg_module_forward_sum=6.8869, avg_module_backward_sum=7.5417\n", 879 | "[Train] Epoch 1, Step 12: loss=1.7435, acc=0.4622\n", 880 | "Timings: avg_end_to_end_forward=10.9614, avg_end_to_end_backward=11.1651, avg_end_to_end_minibatch=22.1414, avg_module_forward_sum=6.8831, avg_module_backward_sum=7.5407\n", 881 | "[Train] Epoch 1, Step 13: loss=1.6589, acc=0.4964\n", 882 | "Timings: avg_end_to_end_forward=10.9593, avg_end_to_end_backward=11.1597, avg_end_to_end_minibatch=22.1338, avg_module_forward_sum=6.8812, avg_module_backward_sum=7.5393\n", 883 | "[Train] Epoch 1, Step 14: loss=1.5815, acc=0.5232\n", 884 | "Timings: avg_end_to_end_forward=10.9577, avg_end_to_end_backward=11.1564, avg_end_to_end_minibatch=22.1289, avg_module_forward_sum=6.8801, 
avg_module_backward_sum=7.5387\n", 885 | "[Train] Epoch 1, Step 15: loss=1.5068, acc=0.5482\n", 886 | "Timings: avg_end_to_end_forward=10.9548, avg_end_to_end_backward=11.1531, avg_end_to_end_minibatch=22.1226, avg_module_forward_sum=6.8784, avg_module_backward_sum=7.5377\n", 887 | "[Train] Epoch 1, Step 16: loss=1.4409, acc=0.5708\n", 888 | "Timings: avg_end_to_end_forward=10.9528, avg_end_to_end_backward=11.1505, avg_end_to_end_minibatch=22.1179, avg_module_forward_sum=6.8770, avg_module_backward_sum=7.5373\n", 889 | "[Train] Epoch 1, Step 17: loss=1.3755, acc=0.5926\n", 890 | "Timings: avg_end_to_end_forward=10.9501, avg_end_to_end_backward=11.1484, avg_end_to_end_minibatch=22.1131, avg_module_forward_sum=6.8749, avg_module_backward_sum=7.5369\n", 891 | "[Train] Epoch 1, Step 18: loss=1.3169, acc=0.6122\n", 892 | "Timings: avg_end_to_end_forward=10.9485, avg_end_to_end_backward=11.1465, avg_end_to_end_minibatch=22.1097, avg_module_forward_sum=6.8733, avg_module_backward_sum=7.5366\n", 893 | "[Train] Epoch 1, Step 19: loss=1.2669, acc=0.6283\n", 894 | "Timings: avg_end_to_end_forward=10.9475, avg_end_to_end_backward=11.1445, avg_end_to_end_minibatch=22.1066, avg_module_forward_sum=6.8717, avg_module_backward_sum=7.5362\n", 895 | "[Train] Epoch 1, Step 20: loss=1.2189, acc=0.6426\n", 896 | "Timings: avg_end_to_end_forward=10.9462, avg_end_to_end_backward=11.1480, avg_end_to_end_minibatch=22.1088, avg_module_forward_sum=6.8707, avg_module_backward_sum=7.5408\n", 897 | "[Train] Epoch 1, Step 21: loss=1.1756, acc=0.6561\n", 898 | "Timings: avg_end_to_end_forward=10.9446, avg_end_to_end_backward=11.1459, avg_end_to_end_minibatch=22.1052, avg_module_forward_sum=6.8697, avg_module_backward_sum=7.5399\n", 899 | "[Train] Epoch 1, Step 22: loss=1.1365, acc=0.6674\n", 900 | "Timings: avg_end_to_end_forward=10.9437, avg_end_to_end_backward=11.1436, avg_end_to_end_minibatch=22.1018, avg_module_forward_sum=6.8690, avg_module_backward_sum=7.5391\n", 901 | "[Train] Epoch 1, Step 23: loss=1.0966, acc=0.6795\n", 902 | "Timings: avg_end_to_end_forward=10.9420, avg_end_to_end_backward=11.1421, avg_end_to_end_minibatch=22.0986, avg_module_forward_sum=6.8679, avg_module_backward_sum=7.5384\n", 903 | "[Train] Epoch 1, Step 24: loss=1.0588, acc=0.6911\n", 904 | "Timings: avg_end_to_end_forward=10.9409, avg_end_to_end_backward=11.1407, avg_end_to_end_minibatch=22.0961, avg_module_forward_sum=6.8672, avg_module_backward_sum=7.5378\n", 905 | "[Train] Epoch 1, Step 25: loss=1.0273, acc=0.7005\n", 906 | "Timings: avg_end_to_end_forward=10.9401, avg_end_to_end_backward=11.1393, avg_end_to_end_minibatch=22.0939, avg_module_forward_sum=6.8662, avg_module_backward_sum=7.5373\n", 907 | "[Train] Epoch 1, Step 26: loss=0.9959, acc=0.7099\n", 908 | "Timings: avg_end_to_end_forward=10.9389, avg_end_to_end_backward=11.1386, avg_end_to_end_minibatch=22.0920, avg_module_forward_sum=6.8654, avg_module_backward_sum=7.5372\n", 909 | "[Train] Epoch 1, Step 27: loss=0.9684, acc=0.7177\n", 910 | "Timings: avg_end_to_end_forward=10.9377, avg_end_to_end_backward=11.1385, avg_end_to_end_minibatch=22.0907, avg_module_forward_sum=6.8644, avg_module_backward_sum=7.5375\n", 911 | "[Train] Epoch 1, Step 28: loss=0.9410, acc=0.7257\n", 912 | "Timings: avg_end_to_end_forward=10.9364, avg_end_to_end_backward=11.1374, avg_end_to_end_minibatch=22.0883, avg_module_forward_sum=6.8635, avg_module_backward_sum=7.5369\n", 913 | "[Train] Epoch 1, Step 29: loss=0.9145, acc=0.7337\n", 914 | "Timings: avg_end_to_end_forward=10.9347, 
avg_end_to_end_backward=11.1365, avg_end_to_end_minibatch=22.0857, avg_module_forward_sum=6.8626, avg_module_backward_sum=7.5365\n", 915 | "[Train] Epoch 1, Step 30: loss=0.8888, acc=0.7417\n", 916 | "Timings: avg_end_to_end_forward=10.9340, avg_end_to_end_backward=11.1346, avg_end_to_end_minibatch=22.0831, avg_module_forward_sum=6.8622, avg_module_backward_sum=7.5355\n", 917 | "[Train] Epoch 1, Step 31: loss=0.8657, acc=0.7485\n", 918 | "Timings: avg_end_to_end_forward=10.9326, avg_end_to_end_backward=11.1333, avg_end_to_end_minibatch=22.0804, avg_module_forward_sum=6.8615, avg_module_backward_sum=7.5347\n", 919 | "[Train] Epoch 1, Step 32: loss=0.8440, acc=0.7551\n", 920 | "Timings: avg_end_to_end_forward=10.9313, avg_end_to_end_backward=11.1323, avg_end_to_end_minibatch=22.0780, avg_module_forward_sum=6.8609, avg_module_backward_sum=7.5343\n", 921 | "[Train] Epoch 1, Step 33: loss=0.8252, acc=0.7603\n", 922 | "Timings: avg_end_to_end_forward=10.9299, avg_end_to_end_backward=11.1316, avg_end_to_end_minibatch=22.0759, avg_module_forward_sum=6.8602, avg_module_backward_sum=7.5339\n", 923 | "[Train] Epoch 1, Step 34: loss=0.8085, acc=0.7652\n", 924 | "Timings: avg_end_to_end_forward=10.9294, avg_end_to_end_backward=11.1306, avg_end_to_end_minibatch=22.0744, avg_module_forward_sum=6.8596, avg_module_backward_sum=7.5335\n", 925 | "[Train] Epoch 1, Step 35: loss=0.7913, acc=0.7702\n", 926 | "Timings: avg_end_to_end_forward=10.9282, avg_end_to_end_backward=11.1299, avg_end_to_end_minibatch=22.0725, avg_module_forward_sum=6.8589, avg_module_backward_sum=7.5331\n", 927 | "[Train] Epoch 1, Step 36: loss=0.7739, acc=0.7753\n", 928 | "Timings: avg_end_to_end_forward=10.9274, avg_end_to_end_backward=11.1286, avg_end_to_end_minibatch=22.0704, avg_module_forward_sum=6.8586, avg_module_backward_sum=7.5325\n", 929 | "[Train] Epoch 1, Step 37: loss=0.7571, acc=0.7798\n", 930 | "Timings: avg_end_to_end_forward=10.9265, avg_end_to_end_backward=11.1282, avg_end_to_end_minibatch=22.0691, avg_module_forward_sum=6.8579, avg_module_backward_sum=7.5323\n", 931 | "[Train] Epoch 1, Step 38: loss=0.7410, acc=0.7843\n", 932 | "Timings: avg_end_to_end_forward=10.9260, avg_end_to_end_backward=11.1273, avg_end_to_end_minibatch=22.0676, avg_module_forward_sum=6.8575, avg_module_backward_sum=7.5317\n", 933 | "[Train] Epoch 1, Step 39: loss=0.7258, acc=0.7887\n", 934 | "Timings: avg_end_to_end_forward=10.9254, avg_end_to_end_backward=11.1263, avg_end_to_end_minibatch=22.0661, avg_module_forward_sum=6.8570, avg_module_backward_sum=7.5312\n", 935 | "[Train] Epoch 1, Step 40: loss=0.7112, acc=0.7929\n", 936 | "Timings: avg_end_to_end_forward=10.9251, avg_end_to_end_backward=11.1261, avg_end_to_end_minibatch=22.0656, avg_module_forward_sum=6.8568, avg_module_backward_sum=7.5313\n", 937 | "[Train] Epoch 1, Step 41: loss=0.6990, acc=0.7965\n", 938 | "Timings: avg_end_to_end_forward=10.9247, avg_end_to_end_backward=11.1253, avg_end_to_end_minibatch=22.0643, avg_module_forward_sum=6.8564, avg_module_backward_sum=7.5309\n", 939 | "[Train] Epoch 1, Step 42: loss=0.6863, acc=0.7999\n", 940 | "Timings: avg_end_to_end_forward=10.9238, avg_end_to_end_backward=11.1250, avg_end_to_end_minibatch=22.0632, avg_module_forward_sum=6.8558, avg_module_backward_sum=7.5310\n", 941 | "[Train] Epoch 1, Step 43: loss=0.6744, acc=0.8031\n", 942 | "Timings: avg_end_to_end_forward=10.9233, avg_end_to_end_backward=11.1247, avg_end_to_end_minibatch=22.0624, avg_module_forward_sum=6.8555, avg_module_backward_sum=7.5309\n", 943 | "[Train] Epoch 1, Step 44: 
loss=0.6626, acc=0.8068\n", 944 | "Timings: avg_end_to_end_forward=10.9226, avg_end_to_end_backward=11.1241, avg_end_to_end_minibatch=22.0611, avg_module_forward_sum=6.8550, avg_module_backward_sum=7.5306\n", 945 | "[Train] Epoch 1, Step 45: loss=0.6506, acc=0.8105\n", 946 | "Timings: avg_end_to_end_forward=10.9218, avg_end_to_end_backward=11.1239, avg_end_to_end_minibatch=22.0601, avg_module_forward_sum=6.8546, avg_module_backward_sum=7.5305\n", 947 | "[Train] Epoch 1, Step 46: loss=0.6392, acc=0.8139\n", 948 | "Timings: avg_end_to_end_forward=10.9208, avg_end_to_end_backward=11.1235, avg_end_to_end_minibatch=22.0587, avg_module_forward_sum=6.8541, avg_module_backward_sum=7.5305\n", 949 | "[Train] Epoch 1, Step 47: loss=0.6296, acc=0.8167\n", 950 | "Timings: avg_end_to_end_forward=10.9201, avg_end_to_end_backward=11.1235, avg_end_to_end_minibatch=22.0580, avg_module_forward_sum=6.8537, avg_module_backward_sum=7.5307\n", 951 | "[Train] Epoch 1, Step 48: loss=0.6197, acc=0.8195\n", 952 | "Timings: avg_end_to_end_forward=10.9199, avg_end_to_end_backward=11.1232, avg_end_to_end_minibatch=22.0574, avg_module_forward_sum=6.8537, avg_module_backward_sum=7.5307\n", 953 | "[Train] Epoch 1, Step 49: loss=0.6094, acc=0.8224\n", 954 | "Timings: avg_end_to_end_forward=10.9195, avg_end_to_end_backward=11.1230, avg_end_to_end_minibatch=22.0568, avg_module_forward_sum=6.8534, avg_module_backward_sum=7.5306\n", 955 | "[Train] Epoch 1, Step 50: loss=0.6001, acc=0.8251\n", 956 | "Timings: avg_end_to_end_forward=10.9193, avg_end_to_end_backward=11.1227, avg_end_to_end_minibatch=22.0564, avg_module_forward_sum=6.8533, avg_module_backward_sum=7.5306\n", 957 | "[Train] Epoch 1, Step 51: loss=0.5944, acc=0.8266\n", 958 | "Timings: avg_end_to_end_forward=10.7752, avg_end_to_end_backward=10.9864, avg_end_to_end_minibatch=21.7758, avg_module_forward_sum=6.7629, avg_module_backward_sum=7.4425\n", 959 | "[Train] Epoch 1: loss=0.5944, acc=0.8266\n", 960 | "Timings: avg_end_to_end_forward=10.7752, avg_end_to_end_backward=10.9864, avg_end_to_end_minibatch=21.7758, avg_module_forward_sum=6.7629, avg_module_backward_sum=7.4425\n", 961 | "Epoch 1 training took 1112.7971.\n", 962 | "[Val] Epoch 1: loss=0.1265, acc=0.9638\n", 963 | "Timings: avg_end_to_end_forward=9.7825, avg_end_to_end_backward=0.0000, avg_end_to_end_minibatch=9.7828, avg_module_forward_sum=6.9401, avg_module_backward_sum=0.0000\n", 964 | "Epoch 1 validation took 20.6065.\n", 965 | "The full finetuning took 1133.4627.\n" 966 | ] 967 | } 968 | ], 969 | "source": [ 970 | "from utils import Finetuner\n", 971 | "\n", 972 | "# initialize the finetuner\n", 973 | "finetuner = Finetuner(\n", 974 | " model,\n", 975 | " optimizer,\n", 976 | " schedular=None, # we could pass an lr schedular here. no need for this example.\n", 977 | " loss_fn=loss_fn,\n", 978 | " log_freq=1, # how often to log (in batches). 1 means that it will log after processing every batch.\n", 979 | " save_freq=1, # how often to save the checkpoint (in epochs). 1 means that it will save a checkpoint after each epoch.\n", 980 | " logger=logger\n", 981 | ").finetune(train_loader, test_loader, epochs=1)" 982 | ] 983 | }, 984 | { 985 | "cell_type": "markdown", 986 | "metadata": {}, 987 | "source": [ 988 | "Notice that in addition to the loss and accuracy metrics, this script also reports the time spent in each part of the process. 
The timings include:\n", 989 | "\n", 990 | "- `avg_end_to_end_forward`: the average time spent in the forward pass, i.e., the model(inputs) line.\n", 991 | "- `avg_end_to_end_backward`: the average time spent in the backward pass, i.e., the loss.backward() line.\n", 992 | "- `avg_end_to_end_minibatch`: the average time spent processing a minibatch. This includes forward pass, backward pass, loss calculation, optimization step, etc. Note that loading the data into memory is not included.\n", 993 | "- `avg_module_forward_sum`: the average time spent in the forward function of the modules torch.nn.Linear, torch.nn.Conv2d, SparseLinear, and SparseConv2d.\n", 994 | "- `avg_module_backward_sum`: the average time spent in the backward function of the modules torch.nn.Linear, torch.nn.Conv2d, SparseLinear, and SparseConv2d.\n" 995 | ] 996 | }, 997 | { 998 | "cell_type": "markdown", 999 | "metadata": {}, 1000 | "source": [ 1001 | "## Conclusion\n", 1002 | "\n", 1003 | "In this notebook we provided step-by-step examples of how to benefit from *SparseProp*'s sparse implementation for speeding-up a single layer, as well as an entire network. For the latter, we took a 95% uniform sparse ResNet18 model (pretrained on ImageNet), and fine-tuned it on the ImageNette dataset." 1004 | ] 1005 | } 1006 | ], 1007 | "metadata": { 1008 | "kernelspec": { 1009 | "display_name": "Python (sp)", 1010 | "language": "python", 1011 | "name": "sp" 1012 | }, 1013 | "language_info": { 1014 | "codemirror_mode": { 1015 | "name": "ipython", 1016 | "version": 3 1017 | }, 1018 | "file_extension": ".py", 1019 | "mimetype": "text/x-python", 1020 | "name": "python", 1021 | "nbconvert_exporter": "python", 1022 | "pygments_lexer": "ipython3", 1023 | "version": "3.9.7" 1024 | } 1025 | }, 1026 | "nbformat": 4, 1027 | "nbformat_minor": 2 1028 | } 1029 | -------------------------------------------------------------------------------- /examples/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from copy import deepcopy 4 | from collections import OrderedDict 5 | 6 | from time import time 7 | 8 | from sparseprop.modules import SparseLinear, SparseConv2d 9 | from sparseprop.utils import swap_module, sparsity 10 | 11 | class Finetuner: 12 | def __init__(self, model, optimizer, schedular, loss_fn, log_freq, save_freq, logger): 13 | self._model = WrappedModel(model) # wrap for layer-wise timing 14 | self._optimizer = optimizer 15 | self._schedular = schedular 16 | self._loss_fn = loss_fn 17 | self._log_freq = log_freq 18 | self._save_freq = save_freq 19 | self._logger = logger 20 | 21 | with torch.no_grad(): 22 | # we need to keep track of the pruned Linear and Conv2d modules, since we need to mask them after each step 23 | is_sparse = apply_to_all_modules_with_types( 24 | self._model, 25 | [torch.nn.Linear, torch.nn.Conv2d], 26 | lambda n, m: sparsity(m) > 0. 
27 | ) 28 | dense_modules_to_keep_sparse = [key for key, value in is_sparse.items() if value] 29 | 30 | # store the sparsity mask of dense modules to apply to their weight after each update 31 | self._sparsity_masks = apply_to_all_modules_with_names( 32 | self._model, 33 | dense_modules_to_keep_sparse, 34 | lambda n, m: (m.weight.data != 0).float() 35 | ) 36 | 37 | def _step(self, inputs, targets, phase, timings=None, profile=False): 38 | assert phase in ['train', 'test'] 39 | train = phase == 'train' 40 | 41 | if timings is None: 42 | timings = OrderedDict() 43 | 44 | if train: 45 | self._optimizer.zero_grad() 46 | 47 | with Timer(timings, 'end_to_end_forward'): 48 | if profile: 49 | with torch.autograd.profiler.profile() as prof: 50 | outputs = self._model(inputs) 51 | print("================================================ Forward Profile ================================================") 52 | print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)) 53 | else: 54 | outputs = self._model(inputs) 55 | 56 | 57 | loss = self._loss_fn(outputs, targets) 58 | 59 | with Timer(timings, 'end_to_end_backward'): 60 | if train: 61 | if profile: 62 | with torch.autograd.profiler.profile() as prof: 63 | loss.backward() 64 | print("================================================ Backward Profile ================================================") 65 | print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)) 66 | else: 67 | loss.backward() 68 | 69 | if train: 70 | self._optimizer.step() 71 | if self._schedular is not None: 72 | self._schedular.step() 73 | 74 | with torch.no_grad(): 75 | apply_to_all_modules_with_names( 76 | self._model, 77 | list(self._sparsity_masks.keys()), 78 | lambda n, m: m.weight.data.mul_(self._sparsity_masks[n]) 79 | ) 80 | 81 | pred = torch.argmax(outputs, dim=-1) 82 | acc = torch.mean((pred == targets).float()) 83 | 84 | return loss.item(), acc.item(), timings 85 | 86 | def _log_timings(self, epoch, step, full_timings, modules_forward_time, modules_backward_time): 87 | l = f'Timings: ' 88 | 89 | for key, value in full_timings.items(): 90 | l += f'avg_{key}={value / (step + 1):.4f}, ' 91 | l += f'avg_module_forward_sum={modules_forward_time:.4f}, ' 92 | l += f'avg_module_backward_sum={modules_backward_time:.4f}' 93 | 94 | self._logger.log(l) 95 | 96 | def _run_epoch(self, loader, epoch, phase): 97 | assert phase in ['train', 'test'] 98 | train = phase == 'train' 99 | 100 | initial_training = self._model.training 101 | self._model.train(train) 102 | 103 | running_loss = 0. 104 | running_acc = 0. 
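        # note: full_timings (initialized below) accumulates the per-step timing dicts, and _log_timings divides each entry by (step + 1), so the reported avg_* numbers are running averages over the epoch so far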
105 | 106 | full_timings = None 107 | 108 | with torch.set_grad_enabled(train): 109 | for step, (inputs, targets) in enumerate(loader): 110 | timings = OrderedDict() 111 | 112 | with Timer(timings, 'end_to_end_minibatch'): 113 | loss, acc, _ = self._step(inputs, targets, phase, timings=timings) 114 | 115 | running_loss += loss 116 | running_acc += acc 117 | 118 | if full_timings is None: 119 | full_timings = timings 120 | else: 121 | for name in timings: 122 | full_timings[name] += timings[name] 123 | 124 | if train and (step + 1) % self._log_freq == 0: 125 | avg_loss = running_loss / (step + 1) 126 | avg_acc = running_acc / (step + 1) 127 | self._logger.log(f'[{"Train" if train else "Val"}] Epoch {epoch + 1}, Step {step + 1}: loss={avg_loss:.4f}, acc={avg_acc:.4f}') 128 | self._log_timings(epoch, step, full_timings, *self._model.get_timings_sum()) 129 | 130 | avg_loss = running_loss / (step + 1) 131 | avg_acc = running_acc / (step + 1) 132 | self._logger.log(f'[{"Train" if train else "Val"}] Epoch {epoch + 1}: loss={avg_loss:.4f}, acc={avg_acc:.4f}') 133 | self._log_timings(epoch, step, full_timings, *self._model.get_timings_sum(reset_afterwards=True)) 134 | 135 | self._model.train(initial_training) 136 | 137 | def finetune(self, train_loader, test_loader, epochs): 138 | timings = OrderedDict() 139 | 140 | with Timer(timings, 'end_to_end_finetuning'): 141 | for epoch in range(epochs): 142 | with Timer(timings, f'epoch_train'): 143 | self._run_epoch(train_loader, epoch, phase='train') 144 | 145 | self._logger.log(f"Epoch {epoch + 1} training took {timings['epoch_train']:.4f}.") 146 | 147 | with Timer(timings, f'epoch_val'): 148 | self._run_epoch(test_loader, epoch, phase='test') 149 | 150 | self._logger.log(f"Epoch {epoch + 1} validation took {timings['epoch_val']:.4f}.") 151 | 152 | if (epoch + 1) % self._save_freq == 0: 153 | torch.save( 154 | self._model.unwrap().state_dict(), 155 | os.path.join(self._logger._outdir, f'checkpoint-{epoch + 1}.pt') 156 | ) 157 | 158 | self._logger.log(f"The full finetuning took {timings['end_to_end_finetuning']:.4f}.") 159 | torch.save(self._model.unwrap().state_dict(), os.path.join(self._logger._outdir, 'final_checkpoint.pt')) 160 | 161 | 162 | # this class can be used as a context manager to time events 163 | class Timer(object): 164 | def __init__(self, target_dict, event_name): 165 | self._target_dict = target_dict 166 | self._event_name = event_name 167 | 168 | def __enter__(self): 169 | self._start = time() 170 | 171 | def __exit__(self, type, value, traceback): 172 | end = time() 173 | self._target_dict[self._event_name] = end - self._start 174 | 175 | class TimingHook: 176 | def __init__(self, tag=None, verbose=False): 177 | self.clear() 178 | 179 | def __call__(self, module, inp, out): 180 | self.times.append(time()) 181 | 182 | def clear(self): 183 | self.times = [] 184 | 185 | class WrappedModel(torch.nn.Module): 186 | def __init__(self, model): 187 | super().__init__() 188 | self._model = model 189 | self.forward = self._model.forward 190 | 191 | bwd_hooks, fwd_hooks, bwd_handles, fwd_handles = self._prepare_for_timing() 192 | self._bwd_hooks = bwd_hooks 193 | self._fwd_hooks = fwd_hooks 194 | self._handles = bwd_handles + fwd_handles 195 | 196 | def _clear_hooks(self): 197 | for hook_pair in self._bwd_hooks + self._fwd_hooks: 198 | _, hook1, hook2 = hook_pair 199 | hook1.clear() 200 | hook2.clear() 201 | 202 | def _prepare_for_timing(self): 203 | # we swap each module with a torch.nn.Sequential consisting of [identity_before, module, 
identity_after] 204 | # we employ hooks for per-layer timing: 205 | # for forward: (time right after module's forward) - (time right after identity_before's forward) 206 | # for backward: (time right after module's backward) - (time right after identity_after's backward) 207 | 208 | backward_hooks, forward_hooks = [], [] 209 | backward_handles, forward_handles = [], [] 210 | 211 | for name, module in self._model.named_modules(): 212 | if isinstance(module, torch.nn.ReLU) or isinstance(module, torch.nn.ReLU6): 213 | module.inplace = False 214 | 215 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, SparseConv2d) or isinstance(module, torch.nn.Linear) or isinstance(module, SparseLinear): 216 | identity_after = torch.nn.Identity() 217 | identity_before = torch.nn.Identity() 218 | 219 | identity_backward_hook = TimingHook() 220 | module_backward_hook = TimingHook() 221 | backward_hooks.append((name, module_backward_hook, identity_backward_hook)) 222 | backward_handles += [ 223 | module.register_full_backward_hook(module_backward_hook), 224 | identity_after.register_full_backward_hook(identity_backward_hook) 225 | ] 226 | 227 | identity_forward_hook = TimingHook() 228 | module_forward_hook = TimingHook() 229 | forward_hooks.append((name, module_forward_hook, identity_forward_hook)) 230 | forward_handles += [ 231 | module.register_forward_hook(module_forward_hook), 232 | identity_before.register_forward_hook(identity_forward_hook) 233 | ] 234 | 235 | swap_module(self._model, name, torch.nn.Sequential(identity_before, module, identity_after)) 236 | 237 | return backward_hooks, forward_hooks, backward_handles, forward_handles 238 | 239 | def unwrap(self, inplace=False): 240 | if inplace: 241 | model = self._model 242 | for handle in self._handles: 243 | handle.remove() 244 | else: 245 | model = deepcopy(self._model) 246 | 247 | replaced_module_names = [h[0] for h in self._fwd_hooks] 248 | for name, module in model.named_modules(): 249 | if name in replaced_module_names: 250 | # replace the torch.nn.Sequential with the original module 251 | swap_module(model, name, module[1]) 252 | 253 | return model 254 | 255 | def _safe_mean(self, arr): 256 | if len(arr) == 0: 257 | return 0. 
258 | return sum(arr) / len(arr) 259 | 260 | def get_per_layer_timings(self, reset_afterwards=False): 261 | per_layer_forward_time = OrderedDict() 262 | per_layer_backward_time = OrderedDict() 263 | 264 | for name, conv_hook, identity_hook in self._fwd_hooks: 265 | per_layer_forward_time[name] = self._safe_mean(conv_hook.times) - self._safe_mean(identity_hook.times) 266 | 267 | for name, conv_hook, identity_hook in self._bwd_hooks: 268 | per_layer_backward_time[name] = self._safe_mean(conv_hook.times) - self._safe_mean(identity_hook.times) 269 | 270 | if reset_afterwards: 271 | self._clear_hooks() 272 | 273 | return per_layer_forward_time, per_layer_backward_time 274 | 275 | def get_timings_sum(self, reset_afterwards=False): 276 | per_layer_forward_time, per_layer_backward_time = self.get_per_layer_timings( 277 | reset_afterwards=reset_afterwards 278 | ) 279 | 280 | return sum(per_layer_forward_time.values()), sum(per_layer_backward_time.values()) 281 | 282 | 283 | class Logger: 284 | def __init__(self, outdir=None): 285 | self._outdir = outdir 286 | 287 | def log(self, l): 288 | if self._outdir is not None: 289 | with open(os.path.join(self._outdir, 'log.txt'), 'a') as f: 290 | f.write(l) 291 | f.write('\n') 292 | print(l) 293 | 294 | def apply_to_all_modules_with_types(model, module_classes, func): 295 | out = OrderedDict() 296 | 297 | for name, module in model.named_modules(): 298 | if any([isinstance(module, c) for c in module_classes]): 299 | out[name] = func(name, module) 300 | 301 | return out 302 | 303 | def apply_to_all_modules_with_names(model, module_names, func): 304 | out = OrderedDict() 305 | 306 | for name, module in model.named_modules(): 307 | if name in module_names: 308 | out[name] = func(name, module) 309 | 310 | return out 311 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages, Extension 2 | import subprocess 3 | 4 | subprocess.run(["pip install pybind11"], shell=True) 5 | 6 | proc = subprocess.Popen(["python3 -m pybind11 --includes"], stdout=subprocess.PIPE, shell=True) 7 | (out, err) = proc.communicate() 8 | out = out.decode('ascii').strip().split() 9 | 10 | setup( 11 | name='sparseprop', 12 | version='0.1.14', 13 | description='SparseProp: Efficient Sparse Backpropagation for Faster Training of Neural Networks', 14 | url='https://github.com/IST-DASLab/sparseprop', 15 | author='Mahdi Nikdan, Tommaso Pegolotti, Eugenia Iofinova, Eldar Kurtic, Dan Alistarh', 16 | author_email='mahdi.nikdan@ist.ac.at, tommaso.pegolotti@inf.ethz.ch, eugenia.iofinova@ist.ac.at, eldar.kurtic@ist.ac.at, dan.alistarh@ist.ac.at', 17 | license='Apache License 2.0', 18 | packages=find_packages(exclude=['tests', 'tests.*']), 19 | ext_modules=[Extension( 20 | 'backend', 21 | [ 22 | 'sparseprop/backend.cpp', 23 | ], 24 | extra_compile_args=['-O3', '-Wall', '-shared', '-std=c++11', '-fPIC', *out, '-march=native', '-fopenmp', '-ffast-math'], 25 | extra_link_args=['-lgomp'] 26 | )], 27 | install_requires=[ 28 | 'setuptools>=59.0', 29 | 'pybind11>=2.0.0', 30 | 'scipy', 31 | ], 32 | include_package_data=True, 33 | classifiers=[ 34 | "Programming Language :: C++", 35 | "Programming Language :: Python :: 3", 36 | "License :: OSI Approved :: Apache Software License", 37 | "Operating System :: POSIX :: Linux", 38 | ], 39 | ) -------------------------------------------------------------------------------- /sparseprop/__init__.py: 
-------------------------------------------------------------------------------- 1 | import backend 2 | import os 3 | 4 | def set_num_threads(num_threads): 5 | os.environ['OMP_NUM_THREADS'] = str(num_threads) -------------------------------------------------------------------------------- /sparseprop/backend.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "lib/sparse_linear.cpp" 6 | #include "lib/sparse_conv2d.cpp" 7 | #include "lib/sparse_conv2d_over_on.cpp" 8 | #include "lib/utils.cpp" 9 | 10 | PYBIND11_MODULE(backend, m) 11 | { 12 | // linear 13 | m.def("sparse_linear_vectorized_forward", &sparse_linear_vectorized_forward_wrapper); 14 | m.def("sparse_linear_vectorized_backward", &sparse_linear_vectorized_backward_wrapper); 15 | 16 | // conv2d 17 | m.def("sparse_conv2d_vectorized_forward_stride_1", &sparse_conv2d_vectorized_forward_stride_1_wrapper); 18 | m.def("sparse_conv2d_vectorized_backward_stride_1", &sparse_conv2d_vectorized_backward_stride_1_wrapper); 19 | m.def("sparse_conv2d_vectorized_forward_stride_2", &sparse_conv2d_vectorized_forward_stride_2_wrapper); 20 | m.def("sparse_conv2d_vectorized_backward_stride_2", &sparse_conv2d_vectorized_backward_stride_2_wrapper); 21 | 22 | // conv2d over on 23 | m.def("sparse_conv2d_vectorized_forward_over_on_stride_1", &sparse_conv2d_vectorized_forward_over_on_stride_1_wrapper); 24 | m.def("sparse_conv2d_vectorized_backward_over_on_stride_1", &sparse_conv2d_vectorized_backward_over_on_stride_1_wrapper); 25 | m.def("sparse_conv2d_vectorized_backward_over_on_stride_2", &sparse_conv2d_vectorized_backward_over_on_stride_2_wrapper); 26 | 27 | // utils 28 | m.def("transpose", &transpose_wrapper); 29 | m.def("sparsify_conv2d", &sparsify_conv2d_wrapper); 30 | m.def("densify_conv2d", &densify_conv2d_wrapper); 31 | m.def("further_sparsify_conv2d", &further_sparsify_conv2d_wrapper); 32 | } -------------------------------------------------------------------------------- /sparseprop/lib/sparse_conv2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace py = pybind11; 7 | 8 | // ====================================== Stride 1 =========================================== 9 | void sparse_conv2d_vectorized_forward_stride_1(int B, int IC, int OC, int M, int N, int K, int W_nnz, int padding, 10 | float* __restrict__ X, int* __restrict__ W_idx_OC, 11 | int16_t* __restrict__ W_idx_IC, uint8_t* __restrict__ W_idx_X, 12 | uint8_t* __restrict__ W_idx_Y, float* __restrict__ W_val, 13 | float* __restrict__ O) { 14 | 15 | const int OM = M + 2 * padding - K + 1; 16 | const int ON = N + 2 * padding - K + 1; 17 | 18 | #pragma omp parallel 19 | { 20 | 21 | #pragma omp for 22 | for (int oc = 0; oc < OC; oc++){ 23 | for (int ic = 0; ic < IC; ic++){ 24 | int oc_s = W_idx_OC[oc]; 25 | int ic_s = oc_s + W_idx_IC[(IC + 1) * oc + ic]; 26 | int ic_e = oc_s + W_idx_IC[(IC + 1) * oc + ic + 1]; 27 | 28 | for (int si = ic_s; si < ic_e; si++) { 29 | uint8_t i = W_idx_X[si]; 30 | uint8_t j = W_idx_Y[si]; 31 | 32 | float v = W_val[si]; 33 | __m256 vv = _mm256_set1_ps(W_val[si]); 34 | 35 | const int pdmi = padding - i; 36 | const int pdmj = padding - j; 37 | const int p_start = std::max(pdmi, 0); 38 | const int p_end = std::min(pdmi + M, OM); 39 | const int q_start = std::max(pdmj, 0); 40 | const int q_end = std::min(pdmj + N, ON); 41 | 42 | for (int po = p_start, px = p_start - padding + i; po 
< p_end; po++, px++) { 43 | int qo = q_start, qx = q_start - padding + j; 44 | for (; qo < q_end-3; qo+=4, qx+=4) { 45 | int b = 0; 46 | for (; b < B-7; b+=8) { 47 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 48 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 49 | 50 | const __m256 o0 = _mm256_loadu_ps(O + Oi); 51 | const __m256 o1 = _mm256_loadu_ps(O + Oi + B); 52 | const __m256 o2 = _mm256_loadu_ps(O + Oi + 2 * B); 53 | const __m256 o3 = _mm256_loadu_ps(O + Oi + 3 * B); 54 | const __m256 x0 = _mm256_loadu_ps(X + Xi); 55 | const __m256 x1 = _mm256_loadu_ps(X + Xi + B); 56 | const __m256 x2 = _mm256_loadu_ps(X + Xi + 2 * B); 57 | const __m256 x3 = _mm256_loadu_ps(X + Xi + 3 * B); 58 | 59 | const __m256 r0 = _mm256_fmadd_ps(x0,vv,o0); 60 | const __m256 r1 = _mm256_fmadd_ps(x1,vv,o1); 61 | const __m256 r2 = _mm256_fmadd_ps(x2,vv,o2); 62 | const __m256 r3 = _mm256_fmadd_ps(x3,vv,o3); 63 | 64 | _mm256_storeu_ps(O + Oi, r0); 65 | _mm256_storeu_ps(O + Oi + B, r1); 66 | _mm256_storeu_ps(O + Oi + 2 * B, r2); 67 | _mm256_storeu_ps(O + Oi + 3 * B, r3); 68 | } 69 | for (; b < B; b++) { 70 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 71 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 72 | 73 | O[Oi] += X[Xi] * v; 74 | } 75 | } 76 | for (; qo < q_end; qo++, qx++) { 77 | int b = 0; 78 | for (; b < B-7; b+=8) { 79 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 80 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 81 | 82 | 83 | const __m256 o = _mm256_loadu_ps(O + Oi); 84 | const __m256 x = _mm256_loadu_ps(X + Xi); 85 | 86 | const __m256 r = _mm256_fmadd_ps(x,vv,o); 87 | 88 | _mm256_storeu_ps(O + Oi, r); 89 | } 90 | for (; b < B; b++) { 91 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 92 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 93 | 94 | O[Oi] += X[Xi] * v; 95 | } 96 | } 97 | } 98 | } 99 | } 100 | } 101 | } 102 | } 103 | 104 | void sparse_conv2d_vectorized_backward_stride_1(int B, int IC, int OC, int M, int N, int K, int W_nnz, int padding, 105 | float* __restrict__ X, int* __restrict__ W_idx_OC, 106 | int16_t* __restrict__ W_idx_IC, uint8_t* __restrict__ W_idx_X, 107 | uint8_t* __restrict__ W_idx_Y, float* __restrict__ W_val, 108 | float* __restrict__ dLdO, float* __restrict__ dLdX, 109 | float* __restrict__ dLdW_val) { 110 | const int OM = M + 2 * padding - K + 1; 111 | const int ON = N + 2 * padding - K + 1; 112 | 113 | #pragma omp parallel 114 | { 115 | 116 | #pragma omp for reduction(+:dLdW_val[:W_nnz]) 117 | for (int ic = 0; ic < IC; ic++){ 118 | for (int oc = 0; oc < OC; oc++){ 119 | int oc_s = W_idx_OC[oc]; 120 | int ic_s = oc_s + W_idx_IC[(IC + 1) * oc + ic]; 121 | int ic_e = oc_s + W_idx_IC[(IC + 1) * oc + ic + 1]; 122 | 123 | for (int si = ic_s; si < ic_e; si++) { 124 | uint8_t i = W_idx_X[si]; 125 | uint8_t j = W_idx_Y[si]; 126 | 127 | float v = W_val[si]; 128 | __m256 vv = _mm256_set1_ps(v); 129 | __m256 dwv = _mm256_setzero_ps(); 130 | float dw = 0; 131 | 132 | const int pdmi = padding - i; 133 | const int pdmj = padding - j; 134 | const int p_start = std::max(pdmi, 0); 135 | const int p_end = std::min(pdmi + M, OM); 136 | const int q_start = std::max(pdmj, 0); 137 | const int q_end = std::min(pdmj + N, ON); 138 | 139 | for (int po = p_start, px = p_start - padding + i; po < p_end; po++, px++) { 140 | int qo = q_start, qx = q_start - padding + j; 141 | for (; qo < q_end; qo++, qx++) { 142 | int b = 0; 143 | for (; b < B-7; b+=8) { 144 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 145 | int Oi = oc * OM * ON * B + po * 
ON * B + qo * B + b; 146 | 147 | const __m256 o = _mm256_loadu_ps(dLdO + Oi); 148 | const __m256 x = _mm256_loadu_ps(X + Xi); 149 | const __m256 dx = _mm256_loadu_ps(dLdX + Xi); 150 | 151 | const __m256 r = _mm256_fmadd_ps(o,vv,dx); 152 | dwv = _mm256_fmadd_ps(o, x, dwv); 153 | 154 | _mm256_storeu_ps(dLdX + Xi, r); 155 | } 156 | for (; b < B; b++) { 157 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 158 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 159 | 160 | float o = dLdO[Oi]; 161 | float x = X[Xi]; 162 | 163 | dLdX[Xi] += o * v; 164 | 165 | dw += o * x; 166 | } 167 | } 168 | } 169 | 170 | const __m128 hiQuad0 = _mm256_extractf128_ps(dwv, 1); 171 | const __m128 loQuad0 = _mm256_castps256_ps128(dwv); 172 | const __m128 sumQuad0 = _mm_add_ps(loQuad0, hiQuad0); 173 | const __m128 hiDual0 = _mm_movehl_ps(sumQuad0, sumQuad0); 174 | const __m128 sumDual0 = _mm_add_ps(sumQuad0, hiDual0); 175 | const __m128 hi0 = _mm_shuffle_ps(sumDual0, sumDual0, 0x1); 176 | const __m128 sum0 = _mm_add_ss(sumDual0, hi0); 177 | 178 | dLdW_val[si] += dw + _mm_cvtss_f32(sum0); 179 | } 180 | } 181 | } 182 | } 183 | } 184 | 185 | // ====================================== Stride 2 =========================================== 186 | 187 | void sparse_conv2d_vectorized_forward_stride_2(int B, int IC, int OC, int M, int N, int K, int W_nnz, int padding, 188 | float* __restrict__ X, int* __restrict__ W_idx_OC, 189 | int16_t* __restrict__ W_idx_IC, uint8_t* __restrict__ W_idx_X, 190 | uint8_t* __restrict__ W_idx_Y, float* __restrict__ W_val, 191 | float* __restrict__ O) { 192 | 193 | const int OM = (int) ceil((float) (M + 2 * padding - K + 1) / 2); 194 | const int ON = (int) ceil((float) (N + 2 * padding - K + 1) / 2); 195 | 196 | #pragma omp parallel 197 | { 198 | 199 | #pragma omp for 200 | for (int oc = 0; oc < OC; oc++){ 201 | for (int ic = 0; ic < IC; ic++){ 202 | int oc_s = W_idx_OC[oc]; 203 | int ic_s = oc_s + W_idx_IC[(IC + 1) * oc + ic]; 204 | int ic_e = oc_s + W_idx_IC[(IC + 1) * oc + ic + 1]; 205 | 206 | for (int si = ic_s; si < ic_e; si++) { 207 | uint8_t i = W_idx_X[si]; 208 | uint8_t j = W_idx_Y[si]; 209 | 210 | float v = W_val[si]; 211 | __m256 vv = _mm256_set1_ps(W_val[si]); 212 | 213 | const int pdmi = padding - i; 214 | const int pdmj = padding - j; 215 | const int p_start = std::max((int) ceil((float) pdmi / 2.0), 0); 216 | const int p_end = std::min((int) floor((float) (pdmi + M - 1) / 2) + 1, OM); 217 | const int q_start = std::max((int) ceil((float) pdmj / 2.0), 0); 218 | const int q_end = std::min((int) floor((float) (pdmj + N - 1) / 2.0) + 1, ON); 219 | 220 | for (int po = p_start, px = 2 * p_start - padding + i; po < p_end; po++, px+=2) { 221 | int qo = q_start, qx = 2 * q_start - padding + j; 222 | for (; qo < q_end-3; qo+=4, qx+=8) { 223 | int b = 0; 224 | for (; b < B-7; b+=8) { 225 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 226 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 227 | 228 | 229 | const __m256 o0 = _mm256_loadu_ps(O + Oi); 230 | const __m256 o1 = _mm256_loadu_ps(O + Oi + B); 231 | const __m256 o2 = _mm256_loadu_ps(O + Oi + 2 * B); 232 | const __m256 o3 = _mm256_loadu_ps(O + Oi + 3 * B); 233 | const __m256 x0 = _mm256_loadu_ps(X + Xi); 234 | const __m256 x1 = _mm256_loadu_ps(X + Xi + 2 * B); 235 | const __m256 x2 = _mm256_loadu_ps(X + Xi + 4 * B); 236 | const __m256 x3 = _mm256_loadu_ps(X + Xi + 6 * B); 237 | 238 | const __m256 r0 = _mm256_fmadd_ps(x0,vv,o0); 239 | const __m256 r1 = _mm256_fmadd_ps(x1,vv,o1); 240 | const __m256 r2 = 
_mm256_fmadd_ps(x2,vv,o2); 241 | const __m256 r3 = _mm256_fmadd_ps(x3,vv,o3); 242 | 243 | _mm256_storeu_ps(O + Oi, r0); 244 | _mm256_storeu_ps(O + Oi + B, r1); 245 | _mm256_storeu_ps(O + Oi + 2 * B, r2); 246 | _mm256_storeu_ps(O + Oi + 3 * B, r3); 247 | } 248 | for (; b < B; b++) { 249 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 250 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 251 | 252 | O[Oi] += X[Xi] * v; 253 | } 254 | } 255 | 256 | for (; qo < q_end; qo++, qx+=2) { 257 | int b = 0; 258 | for (; b < B-7; b+=8) { 259 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 260 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 261 | 262 | 263 | const __m256 o = _mm256_loadu_ps(O + Oi); 264 | const __m256 x = _mm256_loadu_ps(X + Xi); 265 | 266 | const __m256 r = _mm256_fmadd_ps(x,vv,o); 267 | 268 | _mm256_storeu_ps(O + Oi, r); 269 | } 270 | for (; b < B; b++) { 271 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 272 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 273 | 274 | O[Oi] += X[Xi] * v; 275 | } 276 | } 277 | } 278 | } 279 | } 280 | } 281 | } 282 | } 283 | 284 | void sparse_conv2d_vectorized_backward_stride_2(int B, int IC, int OC, int M, int N, int K, int W_nnz, int padding, 285 | float* __restrict__ X, int* __restrict__ W_idx_OC, 286 | int16_t* __restrict__ W_idx_IC, uint8_t* __restrict__ W_idx_X, 287 | uint8_t* __restrict__ W_idx_Y, float* __restrict__ W_val, 288 | float* __restrict__ dLdO, float* __restrict__ dLdX, 289 | float* __restrict__ dLdW_val) { 290 | 291 | const int OM = (int) ceil((float) (M + 2 * padding - K + 1) / 2); 292 | const int ON = (int) ceil((float) (N + 2 * padding - K + 1) / 2); 293 | 294 | #pragma omp parallel 295 | { 296 | 297 | #pragma omp for reduction(+:dLdW_val[:W_nnz]) 298 | for (int ic = 0; ic < IC; ic++){ 299 | for (int oc = 0; oc < OC; oc++){ 300 | int oc_s = W_idx_OC[oc]; 301 | int ic_s = oc_s + W_idx_IC[(IC + 1) * oc + ic]; 302 | int ic_e = oc_s + W_idx_IC[(IC + 1) * oc + ic + 1]; 303 | 304 | for (int si = ic_s; si < ic_e; si++) { 305 | uint8_t i = W_idx_X[si]; 306 | uint8_t j = W_idx_Y[si]; 307 | 308 | float v = W_val[si]; 309 | __m256 vv = _mm256_set1_ps(v); 310 | __m256 dwv = _mm256_setzero_ps(); 311 | float dw = 0; 312 | 313 | const int pdmi = padding - i; 314 | const int pdmj = padding - j; 315 | const int p_start = std::max((int) ceil((float) pdmi / 2.0), 0); 316 | const int p_end = std::min((int) floor((float) (pdmi + M - 1) / 2) + 1, OM); 317 | const int q_start = std::max((int) ceil((float) pdmj / 2.0), 0); 318 | const int q_end = std::min((int) floor((float) (pdmj + N - 1) / 2.0) + 1, ON); 319 | 320 | 321 | for (int po = p_start, px = 2 * p_start - padding + i; po < p_end; po++, px+=2) { 322 | int qo = q_start, qx = 2 * q_start - padding + j; 323 | for (; qo < q_end; qo++, qx+=2) { 324 | int b = 0; 325 | for (; b < B-7; b+=8) { 326 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 327 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 328 | 329 | const __m256 o = _mm256_loadu_ps(dLdO + Oi); 330 | const __m256 x = _mm256_loadu_ps(X + Xi); 331 | const __m256 dx = _mm256_loadu_ps(dLdX + Xi); 332 | 333 | const __m256 r = _mm256_fmadd_ps(o,vv,dx); 334 | dwv = _mm256_fmadd_ps(o, x, dwv); 335 | 336 | _mm256_storeu_ps(dLdX + Xi, r); 337 | } 338 | for (; b < B; b++) { 339 | int Xi = ic * B * M * N + px * N * B + qx * B + b; 340 | int Oi = oc * OM * ON * B + po * ON * B + qo * B + b; 341 | 342 | float o = dLdO[Oi]; 343 | float x = X[Xi]; 344 | 345 | dLdX[Xi] += o * v; 346 | 347 | dw += o * x; 348 | } 349 
| } 350 | } 351 | 352 | const __m128 hiQuad0 = _mm256_extractf128_ps(dwv, 1); 353 | const __m128 loQuad0 = _mm256_castps256_ps128(dwv); 354 | const __m128 sumQuad0 = _mm_add_ps(loQuad0, hiQuad0); 355 | const __m128 hiDual0 = _mm_movehl_ps(sumQuad0, sumQuad0); 356 | const __m128 sumDual0 = _mm_add_ps(sumQuad0, hiDual0); 357 | const __m128 hi0 = _mm_shuffle_ps(sumDual0, sumDual0, 0x1); 358 | const __m128 sum0 = _mm_add_ss(sumDual0, hi0); 359 | 360 | dLdW_val[si] += dw + _mm_cvtss_f32(sum0); 361 | } 362 | } 363 | } 364 | } 365 | } 366 | 367 | // ====================================== Wrappers =========================================== 368 | 369 | void sparse_conv2d_vectorized_forward_stride_1_wrapper(py::array_t X, py::array_t W_idx_OC, 370 | py::array_t W_idx_IC, py::array_t W_idx_X, 371 | py::array_t W_idx_Y, py::array_t W_val, 372 | py::array_t O, int kernel_size, int padding) { 373 | int B = X.shape()[3]; 374 | int IC = X.shape()[0]; 375 | int M = X.shape()[1]; 376 | int N = X.shape()[2]; 377 | int OC = O.shape()[0]; 378 | int W_nnz = W_val.shape()[0]; 379 | int K = kernel_size; 380 | 381 | auto buf_X = X.request(); 382 | auto buf_W_idx_OC = W_idx_OC.request(); 383 | auto buf_W_idx_IC = W_idx_IC.request(); 384 | auto buf_W_idx_X = W_idx_X.request(); 385 | auto buf_W_idx_Y = W_idx_Y.request(); 386 | auto buf_W_val = W_val.request(); 387 | auto buf_O = O.request(); 388 | 389 | float* ptr_X = (float*) buf_X.ptr; 390 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 391 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 392 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 393 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 394 | float* ptr_W_val = (float*) buf_W_val.ptr; 395 | float* ptr_O = (float*) buf_O.ptr; 396 | 397 | sparse_conv2d_vectorized_forward_stride_1(B, IC, OC, M, N, K, W_nnz, padding, ptr_X, ptr_W_idx_OC, ptr_W_idx_IC, 398 | ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_O); 399 | } 400 | 401 | void sparse_conv2d_vectorized_backward_stride_1_wrapper(py::array_t X, py::array_t W_idx_OC, 402 | py::array_t W_idx_IC, py::array_t W_idx_X, 403 | py::array_t W_idx_Y, py::array_t W_val, 404 | py::array_t dLdO, py::array_t dLdX, 405 | py::array_t dLdW_val, int kernel_size, int padding) { 406 | int B = X.shape()[3]; 407 | int IC = X.shape()[0]; 408 | int M = X.shape()[1]; 409 | int N = X.shape()[2]; 410 | int OC = dLdO.shape()[0]; 411 | int W_nnz = W_val.shape()[0]; 412 | int K = kernel_size; 413 | 414 | auto buf_X = X.request(); 415 | auto buf_W_idx_OC = W_idx_OC.request(); 416 | auto buf_W_idx_IC = W_idx_IC.request(); 417 | auto buf_W_idx_X = W_idx_X.request(); 418 | auto buf_W_idx_Y = W_idx_Y.request(); 419 | auto buf_W_val = W_val.request(); 420 | auto buf_dLdO = dLdO.request(); 421 | auto buf_dLdX = dLdX.request(); 422 | auto buf_dLdW_val = dLdW_val.request(); 423 | 424 | float* ptr_X = (float*) buf_X.ptr; 425 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 426 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 427 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 428 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 429 | float* ptr_W_val = (float*) buf_W_val.ptr; 430 | float* ptr_dLdO = (float*) buf_dLdO.ptr; 431 | float* ptr_dLdX = (float*) buf_dLdX.ptr; 432 | float* ptr_dLdW_val = (float*) buf_dLdW_val.ptr; 433 | 434 | sparse_conv2d_vectorized_backward_stride_1(B, IC, OC, M, N, K, W_nnz, padding, ptr_X, ptr_W_idx_OC, ptr_W_idx_IC, 435 | ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_dLdO, ptr_dLdX, ptr_dLdW_val); 436 | } 437 | 438 | void 
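
The wrappers above only unpack shapes and raw buffer pointers before dispatching to the kernels. Below is a usage sketch of driving the stride-1 forward kernel directly through the compiled backend, with the batch-last layouts the wrapper expects; the tensor sizes and the random 90% mask are illustrative, not taken from the repository.

```python
# Calling the stride-1 forward kernel directly: input in (IC, M, N, B) layout,
# output in (OC, OM, ON, B), sparse weight produced by to_sparse_format_conv2d.
import torch
from sparseprop import backend as sppb
from sparseprop.modules.utils import to_sparse_format_conv2d

B, IC, OC, M, N, K, padding = 16, 8, 8, 32, 32, 3, 1
dense_W = torch.randn(OC, IC, K, K) * (torch.rand(OC, IC, K, K) > 0.9)
W_val, W_idx = to_sparse_format_conv2d(dense_W)          # (W_OC, W_IC, W_X, W_Y)

X = torch.randn(B, IC, M, N).permute(1, 2, 3, 0).contiguous()   # batch-last layout
OM, ON = M + 2 * padding - K + 1, N + 2 * padding - K + 1
O = torch.zeros(OC, OM, ON, B)

sppb.sparse_conv2d_vectorized_forward_stride_1(X, *W_idx, W_val, O, K, padding)
O = O.permute(3, 0, 1, 2)                                # back to (B, OC, OM, ON)
```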
sparse_conv2d_vectorized_forward_stride_2_wrapper(py::array_t X, py::array_t W_idx_OC, 439 | py::array_t W_idx_IC, py::array_t W_idx_X, 440 | py::array_t W_idx_Y, py::array_t W_val, 441 | py::array_t O, int kernel_size, int padding) { 442 | int B = X.shape()[3]; 443 | int IC = X.shape()[0]; 444 | int M = X.shape()[1]; 445 | int N = X.shape()[2]; 446 | int OC = O.shape()[0]; 447 | int W_nnz = W_val.shape()[0]; 448 | int K = kernel_size; 449 | 450 | auto buf_X = X.request(); 451 | auto buf_W_idx_OC = W_idx_OC.request(); 452 | auto buf_W_idx_IC = W_idx_IC.request(); 453 | auto buf_W_idx_X = W_idx_X.request(); 454 | auto buf_W_idx_Y = W_idx_Y.request(); 455 | auto buf_W_val = W_val.request(); 456 | auto buf_O = O.request(); 457 | 458 | float* ptr_X = (float*) buf_X.ptr; 459 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 460 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 461 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 462 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 463 | float* ptr_W_val = (float*) buf_W_val.ptr; 464 | float* ptr_O = (float*) buf_O.ptr; 465 | 466 | sparse_conv2d_vectorized_forward_stride_2(B, IC, OC, M, N, K, W_nnz, padding, ptr_X, ptr_W_idx_OC, ptr_W_idx_IC, 467 | ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_O); 468 | } 469 | 470 | void sparse_conv2d_vectorized_backward_stride_2_wrapper(py::array_t X, py::array_t W_idx_OC, 471 | py::array_t W_idx_IC, py::array_t W_idx_X, 472 | py::array_t W_idx_Y, py::array_t W_val, 473 | py::array_t dLdO, py::array_t dLdX, 474 | py::array_t dLdW_val, int kernel_size, int padding) { 475 | int B = X.shape()[3]; 476 | int IC = X.shape()[0]; 477 | int M = X.shape()[1]; 478 | int N = X.shape()[2]; 479 | int OC = dLdO.shape()[0]; 480 | int W_nnz = W_val.shape()[0]; 481 | int K = kernel_size; 482 | 483 | auto buf_X = X.request(); 484 | auto buf_W_idx_OC = W_idx_OC.request(); 485 | auto buf_W_idx_IC = W_idx_IC.request(); 486 | auto buf_W_idx_X = W_idx_X.request(); 487 | auto buf_W_idx_Y = W_idx_Y.request(); 488 | auto buf_W_val = W_val.request(); 489 | auto buf_dLdO = dLdO.request(); 490 | auto buf_dLdX = dLdX.request(); 491 | auto buf_dLdW_val = dLdW_val.request(); 492 | 493 | float* ptr_X = (float*) buf_X.ptr; 494 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 495 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 496 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 497 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 498 | float* ptr_W_val = (float*) buf_W_val.ptr; 499 | float* ptr_dLdO = (float*) buf_dLdO.ptr; 500 | float* ptr_dLdX = (float*) buf_dLdX.ptr; 501 | float* ptr_dLdW_val = (float*) buf_dLdW_val.ptr; 502 | 503 | sparse_conv2d_vectorized_backward_stride_2(B, IC, OC, M, N, K, W_nnz, padding, ptr_X, ptr_W_idx_OC, ptr_W_idx_IC, 504 | ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_dLdO, ptr_dLdX, ptr_dLdW_val); 505 | } -------------------------------------------------------------------------------- /sparseprop/lib/sparse_conv2d_over_on.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace py = pybind11; 7 | 8 | // ====================================== Stride 1 =========================================== 9 | 10 | void sparse_conv2d_vectorized_forward_over_on_stride_1(int B, int IC, int OC, int M, int N, int K, int W_nnz, int padding, 11 | float* __restrict__ X, int* __restrict__ W_idx_OC, 12 | int16_t* __restrict__ W_idx_IC, uint8_t* __restrict__ W_idx_X, 13 | uint8_t* __restrict__ W_idx_Y, float* __restrict__ 
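
The `_over_on` variants vectorize along the output-width (ON) axis rather than the batch axis, so they keep the usual (B, IC, M, N) layout and their wrappers read the batch size from `X.shape()[0]`. A hedged NumPy reference (not library code) of the forward contribution of one nonzero weight; the slice along q stands in for the AVX loop plus its scalar tail, and the sizes are illustrative.

```python
# "Over ON" forward: one nonzero weight adds a shifted window of X into O,
# vectorized along the output-column (q) dimension.
import numpy as np

def forward_over_on_single_nnz(X, O, ic, oc, i, j, v, padding):
    # X: (B, IC, M, N) in plain NCHW layout, O: (B, OC, OM, ON)
    B, _, M, N = X.shape
    _, _, OM, ON = O.shape
    p0, p1 = max(padding - i, 0), min(padding - i + M, OM)
    q0, q1 = max(padding - j, 0), min(padding - j + N, ON)
    for p in range(p0, p1):
        row_x = p - padding + i                  # input row read by output row p
        O[:, oc, p, q0:q1] += v * X[:, ic, row_x, q0 - padding + j : q1 - padding + j]

X = np.random.randn(4, 2, 8, 8).astype(np.float32)
O = np.zeros((4, 3, 8, 8), dtype=np.float32)     # K=3, padding=1 -> OM = ON = 8
forward_over_on_single_nnz(X, O, ic=1, oc=2, i=0, j=2, v=0.25, padding=1)
```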
W_val, 14 | float* __restrict__ O) { 15 | const int OM = M + 2 * padding - K + 1; 16 | const int ON = N + 2 * padding - K + 1; 17 | 18 | #pragma omp parallel for collapse(2) 19 | for (int b = 0; b < B; b++) { 20 | for (int oc = 0; oc < OC; oc++){ 21 | for (int ic = 0; ic < IC; ic++){ 22 | int oc_s = W_idx_OC[oc]; 23 | int ic_s = oc_s + W_idx_IC[(IC + 1) * oc + ic]; 24 | int ic_e = oc_s + W_idx_IC[(IC + 1) * oc + ic + 1]; 25 | 26 | for (int si = ic_s; si < ic_e; si++) { 27 | uint8_t i = W_idx_X[si]; 28 | uint8_t j = W_idx_Y[si]; 29 | 30 | float v = W_val[si]; 31 | __m256 vv = _mm256_set1_ps(v); 32 | 33 | int pdmi = padding - i; 34 | int p_start = pdmi; 35 | if (p_start < 0) p_start = 0; 36 | int p_end = pdmi + M; 37 | if (p_end > OM) p_end = OM; 38 | 39 | int pdmj = padding - j; 40 | int q_start = pdmj; 41 | if (q_start < 0) q_start = 0; 42 | int q_end = pdmj + N ; 43 | if (q_end > ON) q_end = ON; 44 | int q_end_div8 = q_end - ((q_end - q_start) % 8); 45 | 46 | for (int p = p_start; p < p_end; p++) { 47 | for (int q = q_start; q < q_end_div8; q+=8) { 48 | int Xi = b * IC * M * N + ic * M * N + (-padding + p + i) * N + (-padding + q + j); 49 | int Oi = b * OC * OM * ON + oc * OM * ON + p * ON + q; 50 | 51 | __m256 xv = _mm256_loadu_ps(X + Xi); 52 | 53 | __m256 ov = _mm256_loadu_ps(O + Oi); 54 | ov = _mm256_fmadd_ps(xv, vv, ov); 55 | _mm256_storeu_ps(O + Oi, ov); 56 | } 57 | 58 | for (int q = q_end_div8; q < q_end; q++) { 59 | int Xi = b * IC * M * N + ic * M * N + (-padding + p + i) * N + (-padding + q + j); 60 | int Oi = b * OC * OM * ON + oc * OM * ON + p * ON + q; 61 | 62 | float x = X[Xi]; 63 | 64 | O[Oi] += x * v; 65 | } 66 | } 67 | } 68 | } 69 | } 70 | } 71 | } 72 | 73 | void sparse_conv2d_vectorized_backward_over_on_stride_1(int B, int IC, int OC, int M, int N, int K, int W_nnz, int padding, 74 | float* __restrict__ X, int* __restrict__ W_idx_OC, 75 | int16_t* __restrict__ W_idx_IC, uint8_t* __restrict__ W_idx_X, 76 | uint8_t* __restrict__ W_idx_Y, float* __restrict__ W_val, 77 | float* __restrict__ dLdO, float* __restrict__ dLdX, 78 | float* __restrict__ dLdW_val) { 79 | const int OM = M + 2 * padding - K + 1; 80 | const int ON = N + 2 * padding - K + 1; 81 | 82 | #pragma omp parallel for collapse(2) 83 | for (int b = 0; b < B; b++) { 84 | for (int ic = 0; ic < IC; ic++){ 85 | for (int oc = 0; oc < OC; oc++){ 86 | int oc_s = W_idx_OC[oc]; 87 | int ic_s = oc_s + W_idx_IC[(IC + 1) * oc + ic]; 88 | int ic_e = oc_s + W_idx_IC[(IC + 1) * oc + ic + 1]; 89 | 90 | for (int si = ic_s; si < ic_e; si++) { 91 | uint8_t i = W_idx_X[si]; 92 | uint8_t j = W_idx_Y[si]; 93 | 94 | float v = W_val[si]; 95 | __m256 vv = _mm256_set1_ps(v); 96 | __m256 dwv = _mm256_setzero_ps(); 97 | float dw = 0; 98 | 99 | int pdmi = padding - i; 100 | int p_start = pdmi; 101 | if (p_start < 0) p_start = 0; 102 | int p_end = pdmi + M; 103 | if (p_end > OM) p_end = OM; 104 | 105 | int pdmj = padding - j; 106 | int q_start = pdmj; 107 | if (q_start < 0) q_start = 0; 108 | int q_end = pdmj + N; 109 | if (q_end > ON) q_end = ON; 110 | int q_end_div8 = q_end - ((q_end - q_start) % 8); 111 | 112 | 113 | for (int p = p_start; p < p_end; p++) { 114 | for (int q = q_start; q < q_end_div8; q+=8) { 115 | int Xi = b * IC * M * N + ic * M * N + (-padding + p + i) * N + (-padding + q + j); 116 | int Oi = b * OC * OM * ON + oc * OM * ON + p * ON + q; 117 | 118 | __m256 ov = _mm256_loadu_ps(dLdO + Oi); 119 | __m256 xv = _mm256_loadu_ps(X + Xi); 120 | 121 | dwv = _mm256_fmadd_ps(ov, xv, dwv); 122 | 123 | __m256 dxv = _mm256_loadu_ps(dLdX 
+ Xi); 124 | dxv = _mm256_fmadd_ps(ov, vv, dxv); 125 | _mm256_storeu_ps(dLdX + Xi, dxv); 126 | } 127 | 128 | // handle the end of the row 129 | for (int q = q_end_div8; q < q_end; q++) { 130 | int Xi = b * IC * M * N + ic * M * N + (-padding + p + i) * N + (-padding + q + j); 131 | int Oi = b * OC * OM * ON + oc * OM * ON + p * ON + q; 132 | 133 | float o = dLdO[Oi]; 134 | float x = X[Xi]; 135 | 136 | dLdX[Xi] += o * v; 137 | 138 | dw += o * x; 139 | } 140 | } 141 | 142 | dwv = _mm256_hadd_ps(dwv, dwv); 143 | dwv = _mm256_hadd_ps(dwv, dwv); 144 | dw += _mm_cvtss_f32(_mm_add_ss(_mm256_castps256_ps128(dwv), _mm256_extractf128_ps(dwv, 1))); 145 | 146 | #pragma omp atomic 147 | dLdW_val[si] += dw; 148 | } 149 | } 150 | } 151 | } 152 | } 153 | 154 | // ====================================== Stride 2 =========================================== 155 | 156 | void sparse_conv2d_vectorized_backward_over_on_stride_2(int B, int IC, int OC, int M, int N, int K, int W_nnz, int padding, 157 | float* __restrict__ X, int* __restrict__ W_idx_OC, 158 | int16_t* __restrict__ W_idx_IC, uint8_t* __restrict__ W_idx_X, 159 | uint8_t* __restrict__ W_idx_Y, float* __restrict__ W_val, 160 | float* __restrict__ dLdO, float* __restrict__ dLdX, 161 | float* __restrict__ dLdW_val) { 162 | int OM = (int) ceil((float) (M + 2 * padding - K + 1) / 2); 163 | int ON = (int) ceil((float) (N + 2 * padding - K + 1) / 2); 164 | 165 | __m256i permutevar8x32_mask = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); 166 | __m256 zv = _mm256_setzero_ps(); 167 | __m256i permuteamask = _mm256_set_epi32(7,7,7,7,6,4,2,0); 168 | 169 | 170 | #pragma omp parallel for collapse(2) 171 | for (int b = 0; b < B; b++) { 172 | for (int ic = 0; ic < IC; ic++){ 173 | for (int oc = 0; oc < OC; oc++){ 174 | int oc_s = W_idx_OC[oc]; 175 | int ic_s = oc_s + W_idx_IC[(IC + 1) * oc + ic]; 176 | int ic_e = oc_s + W_idx_IC[(IC + 1) * oc + ic + 1]; 177 | 178 | for (int si = ic_s; si < ic_e; si++) { 179 | uint8_t i = W_idx_X[si]; 180 | uint8_t j = W_idx_Y[si]; 181 | 182 | float v = W_val[si]; 183 | __m256 v0v = _mm256_set_ps(0., v, 0., v, 0., v, 0., v); 184 | __m256 dwv = _mm256_setzero_ps(); 185 | float dw = 0; 186 | 187 | int pdmi = padding - i; 188 | int p_start = (int) ceil((float) pdmi / 2); 189 | if (p_start < 0) p_start = 0; 190 | int p_end = (int) floor((float) (pdmi + M - 1) / 2) + 1; 191 | if (p_end > OM) p_end = OM; 192 | 193 | int pdmj = padding - j; 194 | int q_start = (int) ceil((float) pdmj / 2); 195 | if (q_start < 0) q_start = 0; 196 | int q_end = (int) floor((float) (pdmj + N - 1) / 2) + 1; 197 | if (q_end > ON) q_end = ON; 198 | int q_end_div8 = q_end - ((q_end - q_start) % 8); 199 | 200 | 201 | for (int p = p_start; p < p_end; p++) { 202 | for (int q = q_start; q < q_end_div8; q+=8) { 203 | int Xi = b * IC * M * N + ic * M * N + (-padding + 2 * p + i) * N + (-padding + 2 * q + j); 204 | int Oi = b * OC * OM * ON + oc * OM * ON + p * ON + q; 205 | 206 | __m256 ov = _mm256_loadu_ps(dLdO + Oi); 207 | 208 | __m256 a = _mm256_loadu_ps(X + Xi); 209 | __m256 b = _mm256_loadu_ps(X + Xi + 8); 210 | __m256 ap = _mm256_permutevar8x32_ps(a, permuteamask); 211 | __m256 bp = _mm256_permutevar8x32_ps(b, permuteamask); 212 | __m256 xv = _mm256_insertf128_ps(ap, _mm256_castps256_ps128(bp), 1); 213 | 214 | dwv = _mm256_fmadd_ps(ov, xv, dwv); 215 | 216 | ov = _mm256_permutevar8x32_ps(ov, permutevar8x32_mask); 217 | 218 | __m256 ov0 = _mm256_unpacklo_ps(ov, zv); 219 | __m256 ov1 = _mm256_unpackhi_ps(ov, zv); 220 | 221 | __m256 dxv0 = _mm256_loadu_ps(dLdX + Xi); 222 | 
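
In the stride-2 "over ON" backward, eight consecutive output columns read every other input column, so the kernel loads 16 input floats and compacts the even-indexed ones with the permutevar8x32/insertf128 pair above; `dLdW_val` is updated under `#pragma omp atomic` because threads parallelized over (b, ic) can touch the same weight entry. A tiny NumPy sketch of that gather, with made-up values:

```python
# NumPy equivalent of the even-lane gather performed before the FMA.
import numpy as np

x16 = np.arange(16, dtype=np.float32)            # X[Xi : Xi + 16], two AVX loads
xv = x16[0::2]                                   # lanes actually used at stride 2
assert np.array_equal(xv, np.array([0, 2, 4, 6, 8, 10, 12, 14], dtype=np.float32))
```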
__m256 dxv1 = _mm256_loadu_ps(dLdX + Xi + 8); 223 | 224 | dxv0 = _mm256_fmadd_ps(ov0, v0v, dxv0); 225 | dxv1 = _mm256_fmadd_ps(ov1, v0v, dxv1); 226 | 227 | _mm256_storeu_ps(dLdX + Xi, dxv0); 228 | _mm256_storeu_ps(dLdX + Xi + 8, dxv1); 229 | } 230 | 231 | // handle the end of the row 232 | for (int q = q_end_div8; q < q_end; q++) { 233 | int Xi = b * IC * M * N + ic * M * N + (-padding + 2 * p + i) * N + (-padding + 2 * q + j); 234 | int Oi = b * OC * OM * ON + oc * OM * ON + p * ON + q; 235 | 236 | float o = dLdO[Oi]; 237 | float x = X[Xi]; 238 | 239 | dLdX[Xi] += o * v; 240 | 241 | dw += o * x; 242 | } 243 | } 244 | 245 | dwv = _mm256_hadd_ps(dwv, dwv); 246 | dwv = _mm256_hadd_ps(dwv, dwv); 247 | dw += _mm_cvtss_f32(_mm_add_ss(_mm256_castps256_ps128(dwv), _mm256_extractf128_ps(dwv, 1))); 248 | 249 | #pragma omp atomic 250 | dLdW_val[si] += dw; 251 | } 252 | } 253 | } 254 | } 255 | } 256 | 257 | // ====================================== Wrappers =========================================== 258 | 259 | void sparse_conv2d_vectorized_forward_over_on_stride_1_wrapper(py::array_t X, py::array_t W_idx_OC, 260 | py::array_t W_idx_IC, py::array_t W_idx_X, 261 | py::array_t W_idx_Y, py::array_t W_val, 262 | py::array_t O, int kernel_size, int padding) { 263 | int B = X.shape()[0]; 264 | int IC = X.shape()[1]; 265 | int M = X.shape()[2]; 266 | int N = X.shape()[3]; 267 | int OC = O.shape()[1]; 268 | int W_nnz = W_val.shape()[0]; 269 | int K = kernel_size; 270 | 271 | auto buf_X = X.request(); 272 | auto buf_W_idx_OC = W_idx_OC.request(); 273 | auto buf_W_idx_IC = W_idx_IC.request(); 274 | auto buf_W_idx_X = W_idx_X.request(); 275 | auto buf_W_idx_Y = W_idx_Y.request(); 276 | auto buf_W_val = W_val.request(); 277 | auto buf_O = O.request(); 278 | 279 | float* ptr_X = (float*) buf_X.ptr; 280 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 281 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 282 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 283 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 284 | float* ptr_W_val = (float*) buf_W_val.ptr; 285 | float* ptr_O = (float*) buf_O.ptr; 286 | 287 | sparse_conv2d_vectorized_forward_over_on_stride_1(B, IC, OC, M, N, K, W_nnz, padding, ptr_X, ptr_W_idx_OC, ptr_W_idx_IC, ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_O); 288 | } 289 | 290 | void sparse_conv2d_vectorized_backward_over_on_stride_1_wrapper(py::array_t X, py::array_t W_idx_OC, 291 | py::array_t W_idx_IC, py::array_t W_idx_X, 292 | py::array_t W_idx_Y, py::array_t W_val, 293 | py::array_t dLdO, py::array_t dLdX, 294 | py::array_t dLdW_val, int kernel_size, int padding) { 295 | int B = X.shape()[0]; 296 | int IC = X.shape()[1]; 297 | int M = X.shape()[2]; 298 | int N = X.shape()[3]; 299 | int OC = dLdO.shape()[1]; 300 | int W_nnz = W_val.shape()[0]; 301 | int K = kernel_size; 302 | 303 | auto buf_X = X.request(); 304 | auto buf_W_idx_OC = W_idx_OC.request(); 305 | auto buf_W_idx_IC = W_idx_IC.request(); 306 | auto buf_W_idx_X = W_idx_X.request(); 307 | auto buf_W_idx_Y = W_idx_Y.request(); 308 | auto buf_W_val = W_val.request(); 309 | auto buf_dLdO = dLdO.request(); 310 | auto buf_dLdX = dLdX.request(); 311 | auto buf_dLdW_val = dLdW_val.request(); 312 | 313 | float* ptr_X = (float*) buf_X.ptr; 314 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 315 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 316 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 317 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 318 | float* ptr_W_val = (float*) buf_W_val.ptr; 319 | float* ptr_dLdO = 
(float*) buf_dLdO.ptr; 320 | float* ptr_dLdX = (float*) buf_dLdX.ptr; 321 | float* ptr_dLdW_val = (float*) buf_dLdW_val.ptr; 322 | 323 | sparse_conv2d_vectorized_backward_over_on_stride_1(B, IC, OC, M, N, K, W_nnz, padding, ptr_X, ptr_W_idx_OC, ptr_W_idx_IC, ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_dLdO, ptr_dLdX, ptr_dLdW_val); 324 | } 325 | 326 | void sparse_conv2d_vectorized_backward_over_on_stride_2_wrapper(py::array_t X, py::array_t W_idx_OC, 327 | py::array_t W_idx_IC, py::array_t W_idx_X, 328 | py::array_t W_idx_Y, py::array_t W_val, 329 | py::array_t dLdO, py::array_t dLdX, 330 | py::array_t dLdW_val, int kernel_size, int padding) { 331 | int B = X.shape()[0]; 332 | int IC = X.shape()[1]; 333 | int M = X.shape()[2]; 334 | int N = X.shape()[3]; 335 | int OC = dLdO.shape()[1]; 336 | int W_nnz = W_val.shape()[0]; 337 | int K = kernel_size; 338 | 339 | auto buf_X = X.request(); 340 | auto buf_W_idx_OC = W_idx_OC.request(); 341 | auto buf_W_idx_IC = W_idx_IC.request(); 342 | auto buf_W_idx_X = W_idx_X.request(); 343 | auto buf_W_idx_Y = W_idx_Y.request(); 344 | auto buf_W_val = W_val.request(); 345 | auto buf_dLdO = dLdO.request(); 346 | auto buf_dLdX = dLdX.request(); 347 | auto buf_dLdW_val = dLdW_val.request(); 348 | 349 | float* ptr_X = (float*) buf_X.ptr; 350 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 351 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 352 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 353 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 354 | float* ptr_W_val = (float*) buf_W_val.ptr; 355 | float* ptr_dLdO = (float*) buf_dLdO.ptr; 356 | float* ptr_dLdX = (float*) buf_dLdX.ptr; 357 | float* ptr_dLdW_val = (float*) buf_dLdW_val.ptr; 358 | 359 | sparse_conv2d_vectorized_backward_over_on_stride_2(B, IC, OC, M, N, K, W_nnz, padding, ptr_X, ptr_W_idx_OC, ptr_W_idx_IC, ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_dLdO, ptr_dLdX, ptr_dLdW_val); 360 | } -------------------------------------------------------------------------------- /sparseprop/lib/sparse_linear.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace py = pybind11; 7 | 8 | void sparse_linear_vectorized_forward(int B, int M, int N, int W_nnz, float* X, int* W_idx_N, int* W_idx_M, 9 | float* W_val, float* O) { 10 | #pragma omp parallel 11 | { 12 | #pragma omp for 13 | for(int i = 0; i < N; i++){ 14 | int k = W_idx_N[i]; 15 | for(; k < W_idx_N[i+1]; k++){ 16 | int idx = W_idx_M[k]; 17 | __m256 v = _mm256_set1_ps(W_val[k]); 18 | int j = 0; 19 | 20 | for(; j < B-7; j+=8){ 21 | __m256 x = _mm256_loadu_ps(X + (idx * B + j)); 22 | __m256 o = _mm256_loadu_ps(O + (i * B + j)); 23 | 24 | __m256 r = _mm256_fmadd_ps(x,v,o); 25 | 26 | _mm256_storeu_ps(O + (i * B + j), r); 27 | } 28 | 29 | for(; j < B; j++){ 30 | O[i * B + j] += W_val[k] * X[idx * B + j]; 31 | } 32 | } 33 | } 34 | } 35 | } 36 | 37 | void sparse_linear_vectorized_backward(int B, int M, int N, int W_nnz, float* X, int* W_idx_N, int* W_idx_M,float* W_val, 38 | float* dLdO, float* dLdX, float* dLdW_val) { 39 | #pragma omp parallel 40 | { 41 | for(int i = 0; i < N; i++){ 42 | #pragma omp for 43 | for(int j = W_idx_N[i]; j < W_idx_N[i+1]; j++){ 44 | int r = W_idx_M[j]; 45 | float sv = W_val[j]; 46 | __m256 v = _mm256_set1_ps(W_val[j ]); 47 | float sacc = 0; 48 | __m256 acc = _mm256_setzero_ps(); 49 | 50 | int k = 0; 51 | for(; k < B-7; k+=8){ 52 | __m256 dx0 = _mm256_loadu_ps(dLdX + (r * B + k)); 53 | __m256 x0 = _mm256_loadu_ps(X + (r * B + 
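
The linear kernel stores the weight in CSR form over the N output rows (`W_idx_N` is the row pointer, `W_idx_M` the column indices) and computes O = W·X on a feature-major input of shape (M, B). A NumPy/SciPy reference of the same accumulation; the sizes and sparsity below are illustrative.

```python
# Reference for sparse_linear_vectorized_forward: CSR weight times (M, B) input.
import numpy as np
from scipy.sparse import csr_matrix

N, M, B = 64, 128, 32
W = csr_matrix(np.random.randn(N, M) * (np.random.rand(N, M) > 0.9))
X = np.random.randn(M, B)

O = np.zeros((N, B))
for i in range(N):                               # loop structure of the kernel
    for k in range(W.indptr[i], W.indptr[i + 1]):
        O[i] += W.data[k] * X[W.indices[k]]      # FMA over the batch dimension

assert np.allclose(O, W @ X)
```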
k)); 54 | __m256 do0 = _mm256_loadu_ps(dLdO + (i * B + k)); 55 | __m256 s0 = _mm256_fmadd_ps(v, do0, dx0); 56 | acc = _mm256_fmadd_ps(do0,x0,acc); 57 | _mm256_storeu_ps(dLdX + (r * B + k), s0); 58 | } 59 | 60 | //cleanup 61 | for(; k < B; k++){ 62 | dLdX[r*B+k] += sv * dLdO[i * B + k]; 63 | sacc += dLdO[i*B + k] * X[r*B + k]; 64 | } 65 | 66 | //reduce sum 67 | const __m128 hiQuad0 = _mm256_extractf128_ps(acc, 1); 68 | const __m128 loQuad0 = _mm256_castps256_ps128(acc); 69 | const __m128 sumQuad0 = _mm_add_ps(loQuad0, hiQuad0); 70 | const __m128 hiDual0 = _mm_movehl_ps(sumQuad0, sumQuad0); 71 | const __m128 sumDual0 = _mm_add_ps(sumQuad0, hiDual0); 72 | const __m128 hi0 = _mm_shuffle_ps(sumDual0, sumDual0, 0x1); 73 | const __m128 sum0 = _mm_add_ss(sumDual0, hi0); 74 | 75 | dLdW_val[j] = sacc + _mm_cvtss_f32(sum0); 76 | } 77 | } 78 | } 79 | } 80 | 81 | // ====================================== Wrappers =========================================== 82 | 83 | void sparse_linear_vectorized_forward_wrapper(py::array_t X, py::array_t W_idx_N, py::array_t W_idx_M, 84 | py::array_t W_val, py::array_t O) { 85 | int B = X.shape()[1]; 86 | int M = X.shape()[0]; 87 | int N = O.shape()[0]; 88 | int W_nnz = W_val.shape()[0]; 89 | 90 | auto buf_X = X.request(); 91 | auto buf_W_idx_N = W_idx_N.request(); 92 | auto buf_W_idx_M = W_idx_M.request(); 93 | auto buf_W_val = W_val.request(); 94 | auto buf_O = O.request(); 95 | 96 | float* ptr_X = (float*) buf_X.ptr; 97 | int* ptr_W_idx_N = (int*) buf_W_idx_N.ptr; 98 | int* ptr_W_idx_M = (int*) buf_W_idx_M.ptr; 99 | float* ptr_W_val = (float*) buf_W_val.ptr; 100 | float* ptr_O = (float*) buf_O.ptr; 101 | 102 | sparse_linear_vectorized_forward(B, M, N, W_nnz, ptr_X, ptr_W_idx_N, ptr_W_idx_M, ptr_W_val, ptr_O); 103 | } 104 | 105 | void sparse_linear_vectorized_backward_wrapper(py::array_t X, py::array_t W_idx_N, py::array_t W_idx_M, 106 | py::array_t W_val, py::array_t dLdO, py::array_t dLdX, 107 | py::array_t dLdW_val) { 108 | int B = X.shape()[1]; 109 | int M = X.shape()[0]; 110 | int N = dLdO.shape()[0]; 111 | int W_nnz = W_val.shape()[0]; 112 | 113 | auto buf_X = X.request(); 114 | auto buf_W_idx_N = W_idx_N.request(); 115 | auto buf_W_idx_M = W_idx_M.request(); 116 | auto buf_W_val = W_val.request(); 117 | auto buf_dLdO = dLdO.request(); 118 | auto buf_dLdX = dLdX.request(); 119 | auto buf_dLdW_val = dLdW_val.request(); 120 | 121 | float* ptr_X = (float*) buf_X.ptr; 122 | int* ptr_W_idx_N = (int*) buf_W_idx_N.ptr; 123 | int* ptr_W_idx_M = (int*) buf_W_idx_M.ptr; 124 | float* ptr_W_val = (float*) buf_W_val.ptr; 125 | float* ptr_dLdO = (float*) buf_dLdO.ptr; 126 | float* ptr_dLdX = (float*) buf_dLdX.ptr; 127 | float* ptr_dLdW_val = (float*) buf_dLdW_val.ptr; 128 | 129 | sparse_linear_vectorized_backward(B, M, N, W_nnz, ptr_X, ptr_W_idx_N, ptr_W_idx_M, 130 | ptr_W_val, ptr_dLdO, ptr_dLdX,ptr_dLdW_val); 131 | } -------------------------------------------------------------------------------- /sparseprop/lib/utils.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace py = pybind11; 7 | 8 | 9 | // ====================================== Transpose =========================================== 10 | 11 | void tran(float* mat, float* matT, const int lda, const int ldb) { 12 | __m256 r0, r1, r2, r3, r4, r5, r6, r7; 13 | __m256 t0, t1, t2, t3, t4, t5, t6, t7; 14 | 15 | r0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[0*lda+0])), _mm_load_ps(&mat[4*lda+0]), 1); 
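
Backward for the linear layer: each nonzero W[i, r] scatters its value times row i of dLdO into row r of dLdX, and its own gradient is the dot product of dLdO[i] with X[r], which is what the 8-lane accumulator reduced by the shuffle/add sequence computes. A NumPy sketch with illustrative sizes:

```python
# Reference for sparse_linear_vectorized_backward.
import numpy as np
from scipy.sparse import csr_matrix

N, M, B = 64, 128, 32
W = csr_matrix(np.random.randn(N, M) * (np.random.rand(N, M) > 0.9))
X = np.random.randn(M, B)
dLdO = np.random.randn(N, B)

dLdX = np.zeros((M, B))
dLdW_val = np.zeros_like(W.data)
for i in range(N):
    for k in range(W.indptr[i], W.indptr[i + 1]):
        r = W.indices[k]
        dLdX[r] += W.data[k] * dLdO[i]           # input gradient
        dLdW_val[k] = dLdO[i] @ X[r]             # weight gradient for this nonzero

assert np.allclose(dLdX, W.T @ dLdO)
```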
16 | r1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[1*lda+0])), _mm_load_ps(&mat[5*lda+0]), 1); 17 | r2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[2*lda+0])), _mm_load_ps(&mat[6*lda+0]), 1); 18 | r3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[3*lda+0])), _mm_load_ps(&mat[7*lda+0]), 1); 19 | r4 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[0*lda+4])), _mm_load_ps(&mat[4*lda+4]), 1); 20 | r5 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[1*lda+4])), _mm_load_ps(&mat[5*lda+4]), 1); 21 | r6 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[2*lda+4])), _mm_load_ps(&mat[6*lda+4]), 1); 22 | r7 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(&mat[3*lda+4])), _mm_load_ps(&mat[7*lda+4]), 1); 23 | 24 | t0 = _mm256_unpacklo_ps(r0,r1); 25 | t1 = _mm256_unpackhi_ps(r0,r1); 26 | t2 = _mm256_unpacklo_ps(r2,r3); 27 | t3 = _mm256_unpackhi_ps(r2,r3); 28 | t4 = _mm256_unpacklo_ps(r4,r5); 29 | t5 = _mm256_unpackhi_ps(r4,r5); 30 | t6 = _mm256_unpacklo_ps(r6,r7); 31 | t7 = _mm256_unpackhi_ps(r6,r7); 32 | 33 | r0 = _mm256_shuffle_ps(t0,t2, 0x44); 34 | r1 = _mm256_shuffle_ps(t0,t2, 0xEE); 35 | r2 = _mm256_shuffle_ps(t1,t3, 0x44); 36 | r3 = _mm256_shuffle_ps(t1,t3, 0xEE); 37 | r4 = _mm256_shuffle_ps(t4,t6, 0x44); 38 | r5 = _mm256_shuffle_ps(t4,t6, 0xEE); 39 | r6 = _mm256_shuffle_ps(t5,t7, 0x44); 40 | r7 = _mm256_shuffle_ps(t5,t7, 0xEE); 41 | 42 | _mm256_store_ps(&matT[0*ldb], r0); 43 | _mm256_store_ps(&matT[1*ldb], r1); 44 | _mm256_store_ps(&matT[2*ldb], r2); 45 | _mm256_store_ps(&matT[3*ldb], r3); 46 | _mm256_store_ps(&matT[4*ldb], r4); 47 | _mm256_store_ps(&matT[5*ldb], r5); 48 | _mm256_store_ps(&matT[6*ldb], r6); 49 | _mm256_store_ps(&matT[7*ldb], r7); 50 | } 51 | 52 | inline void transpose(float* __restrict__ X, float* __restrict__ XT, const int N, const int M, const int block_size) { 53 | #pragma omp parallel for 54 | for(int i=0; i X, py::array_t XT, int block_size) { 163 | 164 | int N = X.shape()[0]; 165 | int M = X.shape()[1]; 166 | 167 | auto buf_X = X.request(); 168 | auto buf_XT = XT.request(); 169 | 170 | float* ptr_X = (float*) buf_X.ptr; 171 | float* ptr_XT = (float*) buf_XT.ptr; 172 | 173 | transpose(ptr_X, ptr_XT, N, M, block_size); 174 | } 175 | 176 | void sparsify_conv2d_wrapper(int OC, int IC, int K, py::array_t W, py::array_t W_idx_OC, 177 | py::array_t W_idx_IC, py::array_t W_idx_X, 178 | py::array_t W_idx_Y,py::array_t W_val) { 179 | 180 | auto buf_W = W.request(); 181 | auto buf_W_idx_OC = W_idx_OC.request(); 182 | auto buf_W_idx_IC = W_idx_IC.request(); 183 | auto buf_W_idx_X = W_idx_X.request(); 184 | auto buf_W_idx_Y = W_idx_Y.request(); 185 | auto buf_W_val = W_val.request(); 186 | 187 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 188 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 189 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 190 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 191 | float* ptr_W_val = (float*) buf_W_val.ptr; 192 | float* ptr_W= (float*) buf_W.ptr; 193 | 194 | sparsify_conv2d(IC, OC, K, ptr_W, ptr_W_idx_OC, ptr_W_idx_IC, ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val); 195 | } 196 | 197 | void densify_conv2d_wrapper(int OC, int IC, int K, py::array_t W, py::array_t W_idx_OC, 198 | py::array_t W_idx_IC, py::array_t W_idx_X, 199 | py::array_t W_idx_Y,py::array_t W_val) { 200 | 201 | auto buf_W = W.request(); 202 | auto buf_W_idx_OC = W_idx_OC.request(); 203 | auto buf_W_idx_IC = W_idx_IC.request(); 204 | auto buf_W_idx_X = W_idx_X.request(); 205 | 
auto buf_W_idx_Y = W_idx_Y.request(); 206 | auto buf_W_val = W_val.request(); 207 | 208 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 209 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 210 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 211 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 212 | float* ptr_W_val = (float*) buf_W_val.ptr; 213 | float* ptr_W= (float*) buf_W.ptr; 214 | 215 | densify_conv2d(IC, OC, K, ptr_W, ptr_W_idx_OC, ptr_W_idx_IC, ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val); 216 | } 217 | 218 | void further_sparsify_conv2d_wrapper(int OC, int IC, py::array_t W_idx_OC, py::array_t W_idx_IC, 219 | py::array_t W_idx_X, py::array_t W_idx_Y, 220 | py::array_t W_val, py::array_t W_idx_OC_new, 221 | py::array_t W_idx_IC_new, py::array_t W_idx_X_new, 222 | py::array_t W_idx_Y_new, py::array_t W_val_new, 223 | py::array_t mask) { 224 | 225 | auto buf_W_idx_OC = W_idx_OC.request(); 226 | auto buf_W_idx_IC = W_idx_IC.request(); 227 | auto buf_W_idx_X = W_idx_X.request(); 228 | auto buf_W_idx_Y = W_idx_Y.request(); 229 | auto buf_W_val = W_val.request(); 230 | auto buf_W_idx_OC_new = W_idx_OC_new.request(); 231 | auto buf_W_idx_IC_new = W_idx_IC_new.request(); 232 | auto buf_W_idx_X_new = W_idx_X_new.request(); 233 | auto buf_W_idx_Y_new = W_idx_Y_new.request(); 234 | auto buf_W_val_new = W_val_new.request(); 235 | auto buf_mask = mask.request(); 236 | 237 | int* ptr_W_idx_OC = (int*) buf_W_idx_OC.ptr; 238 | int16_t* ptr_W_idx_IC = (int16_t*) buf_W_idx_IC.ptr; 239 | uint8_t* ptr_W_idx_X = (uint8_t*) buf_W_idx_X.ptr; 240 | uint8_t* ptr_W_idx_Y = (uint8_t*) buf_W_idx_Y.ptr; 241 | float* ptr_W_val = (float*) buf_W_val.ptr; 242 | int* ptr_W_idx_OC_new = (int*) buf_W_idx_OC_new.ptr; 243 | int16_t* ptr_W_idx_IC_new = (int16_t*) buf_W_idx_IC_new.ptr; 244 | uint8_t* ptr_W_idx_X_new = (uint8_t*) buf_W_idx_X_new.ptr; 245 | uint8_t* ptr_W_idx_Y_new = (uint8_t*) buf_W_idx_Y_new.ptr; 246 | float* ptr_W_val_new = (float*) buf_W_val_new.ptr; 247 | int* ptr_mask = (int*) buf_mask.ptr; 248 | 249 | further_sparsify_conv2d(IC, OC, ptr_W_idx_OC, ptr_W_idx_IC, ptr_W_idx_X, ptr_W_idx_Y, ptr_W_val, ptr_W_idx_OC_new, ptr_W_idx_IC_new, ptr_W_idx_X_new, ptr_W_idx_Y_new, ptr_W_val_new, ptr_mask); 250 | } -------------------------------------------------------------------------------- /sparseprop/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from copy import deepcopy 4 | 5 | from sparseprop.modules.conv2d import SparseConv2d 6 | from sparseprop.modules.linear import SparseLinear 7 | 8 | @torch.enable_grad() 9 | def run_and_choose(modules, input_shape, verbose=False): 10 | if len(modules) == 1: 11 | if verbose: 12 | print('only one option...') 13 | return modules[0] 14 | 15 | X_orig = torch.randn(*input_shape) 16 | Y_orig = None 17 | 18 | min_time = 1e10 19 | best_module = None 20 | for module in modules: 21 | module_copy = deepcopy(module) 22 | X = X_orig.clone() 23 | X.requires_grad_() 24 | X.retain_grad() 25 | 26 | temp = time.time() 27 | O = module_copy(X) 28 | fwd_time = time.time() - temp 29 | 30 | if Y_orig is None: 31 | Y_orig = torch.randn_like(O) 32 | Y = Y_orig.clone() 33 | 34 | L = torch.mean((O - Y) ** 2) 35 | temp = time.time() 36 | L.backward() 37 | bwd_time = time.time() - temp 38 | 39 | if verbose: 40 | print(f'module {module} took {fwd_time} fwd and {bwd_time} bwd') 41 | 42 | full_time = fwd_time + bwd_time 43 | if full_time < min_time: 44 | min_time = full_time 45 | best_module = module 46 | 
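
`run_and_choose` (continued below) simply times one forward and one backward pass of each candidate module on a random input with an MSE loss and returns the fastest. A usage sketch picking between the two SparseConv2d vectorization strategies; the layer shape and sparsity are illustrative.

```python
# Benchmark-and-pick between batch-vectorized and ON-vectorized sparse conv.
import torch
from sparseprop.modules import run_and_choose
from sparseprop.modules.conv2d import SparseConv2d

conv = torch.nn.Conv2d(64, 64, 3, padding=1, bias=False)
with torch.no_grad():
    conv.weight *= (torch.rand_like(conv.weight) > 0.95)   # roughly 95% sparse

sp_batch = SparseConv2d.from_dense(conv, vectorizing_over_on=False)
sp_on = SparseConv2d.from_dense(conv, vectorizing_over_on=True)
best = run_and_choose([sp_batch, sp_on], input_shape=(16, 64, 32, 32), verbose=True)
```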
47 | if verbose: 48 | print(f'going with {best_module} with full time of {min_time}') 49 | return best_module 50 | 51 | def _sparsify_if_faster_linear(module, input_shape, include_dense, verbose): 52 | sp = SparseLinear( 53 | dense_weight=module.weight.data, 54 | bias=None if module.bias is None else torch.nn.Parameter(module.bias.data.clone()) 55 | ) 56 | 57 | if not include_dense: 58 | return sp 59 | 60 | assert input_shape is not None 61 | return run_and_choose([module, sp], input_shape, verbose=verbose) 62 | 63 | def _sparsify_if_faster_conv2d(conv, input_shape, include_dense, verbose): 64 | def bias_to_param(): 65 | if conv.bias is None: 66 | return None 67 | return torch.nn.Parameter(conv.bias.data.clone()) 68 | 69 | dense_weight = conv.weight.data 70 | stride = conv.stride[0] 71 | padding = conv.padding[0] 72 | 73 | sp1 = SparseConv2d( 74 | dense_weight, 75 | bias=bias_to_param(), 76 | padding=padding, 77 | stride=stride, 78 | vectorizing_over_on=False 79 | ) 80 | 81 | sp2 = SparseConv2d( 82 | dense_weight, 83 | bias=bias_to_param(), 84 | padding=padding, 85 | stride=stride, 86 | vectorizing_over_on=True 87 | ) 88 | 89 | modules = [] 90 | if include_dense: 91 | modules.append(conv) 92 | modules += [sp1, sp2] 93 | 94 | return run_and_choose(modules, input_shape, verbose=verbose) 95 | 96 | def sparsify_if_faster(module, input_shape, include_dense=True, verbose=False): 97 | if isinstance(module, torch.nn.Linear): 98 | return _sparsify_if_faster_linear(module, input_shape, include_dense, verbose) 99 | else: 100 | assert isinstance(module, torch.nn.Conv2d) 101 | return _sparsify_if_faster_conv2d(module, input_shape, include_dense, verbose) 102 | 103 | def sparsify_conv2d_auto(conv, input_shape, verbose=False): 104 | return _sparsify_if_faster_conv2d(conv, input_shape, include_dense=False, verbose=verbose) 105 | 106 | class TimingHook: 107 | def __init__(self, tag=None, verbose=False): 108 | self.clear() 109 | self._tag = tag 110 | self._verbose = verbose 111 | 112 | def __call__(self, module, inp, out): 113 | self.time = time.time() 114 | self.count += 1 115 | 116 | if isinstance(inp, tuple): 117 | inp = inp[0] 118 | if isinstance(out, tuple): 119 | out = out[0] 120 | 121 | if self._verbose: 122 | print(f"[Hook {self._tag}] inp: {inp.shape if inp is not None else inp}, out: {out.shape if out is not None else None}") 123 | 124 | def clear(self): 125 | self.time = None 126 | self.count = 0 127 | 128 | 129 | def attach_identity_and_time(module, X, Y, time_forward=False, time_backward=True): 130 | if not time_backward: 131 | t = time.time() 132 | O = module(X) 133 | return time.time() - t 134 | else: 135 | identity = torch.nn.Identity() 136 | 137 | X.requires_grad_() 138 | X.retain_grad() 139 | 140 | module_hook = TimingHook() 141 | identity_hook = TimingHook() 142 | handles = [ 143 | module.register_full_backward_hook(module_hook), 144 | identity.register_full_backward_hook(identity_hook) 145 | ] 146 | 147 | t = time.time() 148 | O_before_identity = module(X) 149 | forward_time = time.time() - t 150 | 151 | O = identity(O_before_identity) 152 | L = torch.mean((O - Y) ** 2) 153 | 154 | L.backward() 155 | 156 | backward_time = module_hook.time - identity_hook.time 157 | 158 | for handle in handles: 159 | handle.remove() 160 | 161 | if time_forward: 162 | return backward_time, forward_time 163 | else: 164 | return backward_time -------------------------------------------------------------------------------- /sparseprop/modules/conv2d.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from sparseprop.modules.utils import to_sparse_format_conv2d, from_sparse_format_conv2d 3 | from sparseprop.modules.functions import SparseConvFunction 4 | from sparseprop import backend as sppb 5 | from copy import deepcopy 6 | 7 | class SparseConv2d(torch.nn.Module): 8 | def __init__(self, dense_weight, bias=None, padding=0, stride=1, vectorizing_over_on=False): 9 | super(SparseConv2d, self).__init__() 10 | 11 | self.OC, self.IC, self.K, _ = dense_weight.shape 12 | self.padding = padding 13 | self.stride = stride 14 | self.set_vectorizing_over_on(vectorizing_over_on) 15 | 16 | W_val, W_idx = to_sparse_format_conv2d(dense_weight) 17 | 18 | self.W_val = torch.nn.Parameter(W_val) 19 | self.W_idx = W_idx 20 | 21 | assert bias is None or isinstance(bias, torch.nn.Parameter), f"bias is not a parameter but it's {type(bias)}" 22 | self.bias = bias 23 | 24 | @staticmethod 25 | def from_dense(conv, vectorizing_over_on=False): 26 | def bias_to_param(): 27 | if conv.bias is None: 28 | return None 29 | return torch.nn.Parameter(conv.bias.data.clone()) 30 | 31 | dense_weight = conv.weight.data 32 | stride = conv.stride[0] 33 | padding = conv.padding[0] 34 | 35 | return SparseConv2d( 36 | dense_weight, 37 | bias=bias_to_param(), 38 | padding=padding, 39 | stride=stride, 40 | vectorizing_over_on=vectorizing_over_on 41 | ) 42 | 43 | def to_dense(self): 44 | dense_weight = from_sparse_format_conv2d( 45 | self.W_val, 46 | self.W_idx, 47 | shape=(self.OC, self.IC, self.K, self.K) 48 | ) 49 | 50 | conv = torch.nn.Conv2d( 51 | self.IC, 52 | self.OC, 53 | self.K, 54 | stride=self.stride, 55 | padding=self.padding, 56 | bias=self.bias is not None 57 | ) 58 | 59 | with torch.no_grad(): 60 | conv.weight.mul_(0) 61 | conv.weight.add_(dense_weight) 62 | 63 | if self.bias is not None: 64 | conv.bias.mul_(0) 65 | conv.bias.add_(self.bias) 66 | 67 | return conv 68 | 69 | def set_vectorizing_over_on(self, vectorizing_over_on): 70 | self.vectorizing_over_on = vectorizing_over_on 71 | self.vectorizing_bwd_over_on = vectorizing_over_on 72 | self.vectorizing_fwd_over_on = vectorizing_over_on and self.stride == 1 # stride 2 is not supported over on 73 | 74 | @property 75 | def weight(self): 76 | return self.W_val 77 | 78 | def forward(self, input): 79 | return SparseConvFunction.apply(input, self.W_val, self.W_idx, self.bias, self.OC, self.K, self.padding, self.stride, self.vectorizing_fwd_over_on, self.vectorizing_bwd_over_on) 80 | 81 | @torch.no_grad() 82 | def apply_further_mask(self, new_mask, input_shape=None, verbose=False): 83 | new_mask = (new_mask * 1).int() 84 | W_idx_OC, W_idx_IC, W_idx_X, W_idx_Y = self.W_idx 85 | new_nnz = torch.sum(new_mask).item() 86 | W_val_new = torch.zeros(new_nnz).float() 87 | W_idx_OC_new = torch.zeros_like(W_idx_OC).int() 88 | W_idx_IC_new = torch.zeros_like(W_idx_IC).type(torch.short) 89 | W_idx_X_new = torch.zeros(new_nnz).type(torch.uint8) 90 | W_idx_Y_new = torch.zeros(new_nnz).type(torch.uint8) 91 | sppb.further_sparsify_conv2d( 92 | self.OC, 93 | self.IC, 94 | W_idx_OC, 95 | W_idx_IC, 96 | W_idx_X, 97 | W_idx_Y, 98 | self.W_val.data, 99 | W_idx_OC_new, 100 | W_idx_IC_new, 101 | W_idx_X_new, 102 | W_idx_Y_new, 103 | W_val_new, 104 | new_mask 105 | ) 106 | 107 | sp1 = deepcopy(self) 108 | sp1.W_val = torch.nn.Parameter(W_val_new) 109 | sp1.W_idx = W_idx_OC_new, W_idx_IC_new, W_idx_X_new, W_idx_Y_new 110 | 111 | sp2 = deepcopy(sp1) 112 | sp2.set_vectorizing_over_on(not 
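
A round-trip sketch for SparseConv2d: `from_dense` packs the nonzeros into `W_val` plus the (W_idx_OC, W_idx_IC, W_idx_X, W_idx_Y) index tensors, `to_dense` rebuilds an equivalent `torch.nn.Conv2d`, and the forward matches the dense layer up to float accumulation order. The shapes and the random 90% mask below are illustrative.

```python
import torch
from sparseprop.modules.conv2d import SparseConv2d

conv = torch.nn.Conv2d(16, 32, 3, stride=1, padding=1)
with torch.no_grad():
    conv.weight *= (torch.rand_like(conv.weight) > 0.9)    # illustrative sparsity

sp = SparseConv2d.from_dense(conv)
x = torch.randn(8, 16, 28, 28)
assert torch.allclose(sp(x), conv(x), atol=1e-4)            # same conv, sparse kernel
assert torch.allclose(sp.to_dense().weight, conv.weight)    # exact weight round trip
```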
sp1.vectorizing_over_on) 113 | 114 | sp = run_and_choose([sp1, sp2], input_shape=input_shape, verbose=verbose) 115 | self.W_val = torch.nn.Parameter(W_val_new) 116 | self.W_idx = W_idx_OC_new, W_idx_IC_new, W_idx_X_new, W_idx_Y_new 117 | self.set_vectorizing_over_on(sp.vectorizing_over_on) 118 | 119 | def __repr__(self): 120 | nnz = len(self.W_val) 121 | numel = self.OC * self.IC * self.K * self.K 122 | return f'SparseConv2d([{self.OC}, {self.IC}, {self.K}, {self.K}], sp={1. - nnz/numel:.2f}, nnz={nnz}, s={self.stride}, p={self.padding}, voo={self.vectorizing_over_on})' -------------------------------------------------------------------------------- /sparseprop/modules/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from sparseprop import backend as sppb 5 | 6 | TRANSPOSE_BLOCK_SIZE = 16 7 | 8 | class SparseLinearFunction(torch.autograd.Function): 9 | 10 | @staticmethod 11 | def forward(ctx, inputT, W_val, W_idx, bias, N): 12 | input_flat_t = inputT.reshape(-1, inputT.shape[-1]) 13 | B, M = input_flat_t.shape 14 | if B % TRANSPOSE_BLOCK_SIZE == 0 and M % TRANSPOSE_BLOCK_SIZE == 0: 15 | input_flat = torch.zeros(M, B) 16 | sppb.transpose(input_flat_t, input_flat, TRANSPOSE_BLOCK_SIZE) 17 | else: 18 | input_flat = input_flat_t.t().contiguous() 19 | ctx.inputT_shape = inputT.shape 20 | 21 | M, B = input_flat.shape 22 | 23 | W_idx_N, W_idx_M = W_idx 24 | 25 | output = torch.zeros(N, B).float() 26 | sppb.sparse_linear_vectorized_forward(input_flat, W_idx_N, W_idx_M, W_val, output) 27 | 28 | ctx.save_for_backward(W_val, bias) 29 | ctx.svd = (input_flat, W_idx_N, W_idx_M) 30 | 31 | if bias is not None: 32 | output += bias.view(-1, 1) 33 | 34 | if B % TRANSPOSE_BLOCK_SIZE == 0 and N % TRANSPOSE_BLOCK_SIZE == 0: 35 | output_t = torch.zeros(B, N) 36 | sppb.transpose(output, output_t, TRANSPOSE_BLOCK_SIZE) 37 | else: 38 | output_t = output.t() # (B, N) 39 | output_t = output_t.reshape(*ctx.inputT_shape[:-1], N) 40 | return output_t 41 | 42 | @staticmethod 43 | def backward(ctx, grad_output_t): 44 | W_val, bias = ctx.saved_tensors 45 | input_flat, W_idx_N, W_idx_M = ctx.svd 46 | 47 | grad_output_t = grad_output_t.reshape(-1, grad_output_t.shape[-1]).contiguous() 48 | B, N = grad_output_t.shape 49 | if B % TRANSPOSE_BLOCK_SIZE == 0 and N % TRANSPOSE_BLOCK_SIZE == 0: 50 | grad_output = torch.zeros(N, B) 51 | sppb.transpose(grad_output_t, grad_output, TRANSPOSE_BLOCK_SIZE) 52 | else: 53 | grad_output = grad_output_t.t().contiguous() 54 | 55 | grad_input = torch.zeros_like(input_flat).float().contiguous() # (M, B) 56 | grad_W_val = torch.zeros_like(W_val).float().contiguous() 57 | 58 | sppb.sparse_linear_vectorized_backward( 59 | input_flat, 60 | W_idx_N, 61 | W_idx_M, 62 | W_val, 63 | grad_output, 64 | grad_input, 65 | grad_W_val 66 | ) 67 | 68 | M = input_flat.shape[0] 69 | if B % TRANSPOSE_BLOCK_SIZE == 0 and M % TRANSPOSE_BLOCK_SIZE == 0: 70 | grad_input_t = torch.zeros(B, M) 71 | sppb.transpose(grad_input, grad_input_t, TRANSPOSE_BLOCK_SIZE) 72 | else: 73 | grad_input_t = grad_input.t() # (B, M) 74 | grad_input_t = grad_input_t.reshape(ctx.inputT_shape) 75 | 76 | grad_bias = None 77 | if bias is not None: 78 | grad_bias = grad_output_t.sum([i for i in range(len(grad_output_t.shape) - 1)]) 79 | 80 | return grad_input_t, grad_W_val, None, grad_bias, None 81 | 82 | 83 | class SparseConvFunction(torch.autograd.Function): 84 | 85 | @staticmethod 86 | def forward(ctx, input, W_val, W_idx, bias, OC, K, padding, stride, 
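
SparseLinearFunction evaluates the usual y = x·Wᵀ + b, but internally works on feature-major (M, B) buffers (using the AVX transpose when both dimensions are multiples of 16) so the kernel can vectorize along the batch. A small equivalence sketch with illustrative sizes:

```python
import torch
from sparseprop.modules.linear import SparseLinear

lin = torch.nn.Linear(128, 64)
with torch.no_grad():
    lin.weight *= (torch.rand_like(lin.weight) > 0.9)

sp = SparseLinear.from_dense(lin)
x = torch.randn(32, 128, requires_grad=True)
y = sp(x)
assert torch.allclose(y, lin(x), atol=1e-4)
y.sum().backward()                               # exercises the sparse backward kernel
```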
vectorizing_fwd_over_on, vectorizing_bwd_over_on): 87 | orig_input = input 88 | 89 | assert stride in [1, 2], 'only strides 1 and 2 are supported' 90 | 91 | B, IC, M, N = orig_input.shape 92 | OM = math.ceil((M + 2 * padding - K + 1) / stride) 93 | ON = math.ceil((N + 2 * padding - K + 1) / stride) 94 | 95 | if vectorizing_fwd_over_on: 96 | assert stride == 1 # only stride 1 is supported in this case, for now 97 | output = torch.zeros(B, OC, OM, ON).float() 98 | sppb.sparse_conv2d_vectorized_forward_over_on_stride_1(input, *W_idx, W_val, output, K, padding) 99 | else: 100 | input = input.permute(1, 2, 3, 0).contiguous() 101 | output = torch.zeros(OC, OM, ON, B).float() 102 | if stride == 1: 103 | sppb.sparse_conv2d_vectorized_forward_stride_1(input, *W_idx, W_val, output, K, padding) 104 | elif stride == 2: 105 | sppb.sparse_conv2d_vectorized_forward_stride_2(input, *W_idx, W_val, output, K, padding) 106 | 107 | output = output.permute(3, 0, 1, 2) 108 | 109 | if vectorizing_bwd_over_on: # backward needs the original shape 110 | ctx.save_for_backward(W_val, bias) 111 | ctx.svd = (orig_input, *W_idx) 112 | else: 113 | ctx.save_for_backward(W_val, bias) 114 | ctx.svd = (input, *W_idx) 115 | ctx.K, ctx.padding, ctx.stride = K, padding, stride 116 | ctx.vectorizing_bwd_over_on = vectorizing_bwd_over_on 117 | 118 | if bias is not None: 119 | output += bias.view(1, -1, 1, 1) 120 | 121 | return output 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | W_val, bias = ctx.saved_tensors 126 | input, W_idx_OC, W_idx_IC, W_idx_X, W_idx_Y = ctx.svd 127 | K, padding, stride = ctx.K, ctx.padding, ctx.stride 128 | 129 | vectorizing_bwd_over_on = ctx.vectorizing_bwd_over_on 130 | 131 | grad_input = torch.zeros_like(input).float() 132 | grad_W_val = torch.zeros_like(W_val).float() 133 | 134 | assert stride in [1, 2], 'only stride 1 and 2 are supported' 135 | 136 | if vectorizing_bwd_over_on: 137 | if stride == 1: 138 | sppb.sparse_conv2d_vectorized_backward_over_on_stride_1(input, W_idx_OC, W_idx_IC, W_idx_X, W_idx_Y, W_val, grad_output, grad_input, grad_W_val, K, padding) 139 | else: 140 | sppb.sparse_conv2d_vectorized_backward_over_on_stride_2(input, W_idx_OC, W_idx_IC, W_idx_X, W_idx_Y, W_val, grad_output, grad_input, grad_W_val, K, padding) 141 | else: 142 | go = grad_output.permute(1, 2, 3, 0).contiguous() 143 | if stride == 1: 144 | sppb.sparse_conv2d_vectorized_backward_stride_1(input, W_idx_OC, W_idx_IC, W_idx_X, W_idx_Y, W_val, go, grad_input, grad_W_val, K, padding) 145 | else: 146 | sppb.sparse_conv2d_vectorized_backward_stride_2(input, W_idx_OC, W_idx_IC, W_idx_X, W_idx_Y, W_val, go, grad_input, grad_W_val, K, padding) 147 | 148 | grad_bias = None 149 | if bias is not None: 150 | grad_bias = grad_output.sum(dim=(0, 2, 3)) 151 | 152 | if not vectorizing_bwd_over_on: 153 | grad_input = grad_input.permute(3, 0, 1, 2) 154 | 155 | return grad_input, grad_W_val, None, grad_bias, None, None, None, None, None, None 156 | 157 | -------------------------------------------------------------------------------- /sparseprop/modules/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.sparse import csr_matrix 3 | 4 | from sparseprop.modules.functions import SparseLinearFunction 5 | from sparseprop.modules.utils import to_csr_2d, from_csr_2d 6 | 7 | class SparseLinear(torch.nn.Module): 8 | def __init__(self, dense_weight, bias=None): 9 | super(SparseLinear, self).__init__() 10 | self.N, self.M = dense_weight.shape 11 | 12 | 
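
The output spatial size used by SparseConvFunction and the C++ kernels is ceil((M + 2·padding − K + 1) / stride), which for stride 2 is the ceil division implemented above and agrees with PyTorch's floor-based convolution formula. A two-line check with illustrative sizes:

```python
import math

def conv_out_size(size, K, padding, stride):
    return math.ceil((size + 2 * padding - K + 1) / stride)

assert conv_out_size(32, 3, 1, 1) == 32
assert conv_out_size(32, 3, 1, 2) == 16
assert conv_out_size(7, 1, 0, 2) == 4
```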
W_val, W_idx = to_csr_2d(dense_weight) 13 | self.W_val = torch.nn.Parameter(W_val) 14 | self.W_idx = W_idx 15 | 16 | assert bias is None or isinstance(bias, torch.nn.Parameter), f"bias is not a parameter but it's {type(bias)}" 17 | self.bias = bias 18 | 19 | @staticmethod 20 | def from_dense(module): 21 | return SparseLinear( 22 | dense_weight=module.weight.data, 23 | bias=None if module.bias is None else torch.nn.Parameter(module.bias.data.clone()) 24 | ) 25 | 26 | def to_dense(self): 27 | dense_weight = from_csr_2d( 28 | self.W_val, 29 | self.W_idx, 30 | shape=(self.N, self.M) 31 | ) 32 | 33 | linear = torch.nn.Linear( 34 | self.M, 35 | self.N, 36 | bias=self.bias is not None 37 | ) 38 | 39 | with torch.no_grad(): 40 | linear.weight.mul_(0) 41 | linear.weight.add_(dense_weight) 42 | 43 | if self.bias is not None: 44 | linear.bias.mul_(0) 45 | linear.bias.add_(self.bias) 46 | 47 | return linear 48 | 49 | @property 50 | def weight(self): 51 | return self.W_val 52 | 53 | def forward(self, input): 54 | return SparseLinearFunction.apply(input, self.W_val, self.W_idx, self.bias, self.N) 55 | 56 | @torch.no_grad() 57 | def apply_further_mask(self, new_mask): 58 | """ 59 | This function is used when we need to further sparsify a sparse module, e.g., gradual pruning. 60 | """ 61 | 62 | indptr, indices = self.W_idx 63 | dense_weight = torch.Tensor(csr_matrix(( 64 | self.W_val.data, 65 | indices, 66 | indptr 67 | ), shape=(self.N, self.M)).toarray()).float() 68 | 69 | dense_mask = torch.Tensor(csr_matrix(( 70 | new_mask, 71 | indices, 72 | indptr 73 | ), shape=(self.N, self.M)).toarray()).float() 74 | 75 | W_val, W_idx = to_csr_2d(dense_weight * dense_mask) 76 | self.W_val = torch.nn.Parameter(W_val) 77 | self.W_idx = W_idx 78 | 79 | def __repr__(self): 80 | nnz = len(self.W_val) 81 | numel = self.N * self.M 82 | return f"SparseLinear([{self.N}, {self.M}], sp={1. 
- nnz/numel:.2f}, nnz={nnz})" 83 | -------------------------------------------------------------------------------- /sparseprop/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | import time 4 | from scipy.sparse import csr_matrix 5 | from sparseprop import backend as sppb 6 | 7 | def to_csr_2d(data): 8 | if isinstance(data, torch.nn.Parameter): 9 | data = data.data 10 | spa = csr_matrix(data, shape=data.shape) 11 | val = torch.Tensor(spa.data) 12 | idx_N = torch.Tensor(spa.indptr).int() 13 | idx_M = torch.Tensor(spa.indices).int() 14 | return val, (idx_N, idx_M) 15 | 16 | def from_csr_2d(val, idx, shape): 17 | if isinstance(val, torch.nn.Parameter): 18 | val = val.data 19 | idx_N, idx_M = idx 20 | return torch.Tensor(csr_matrix(( 21 | val, 22 | idx_M, 23 | idx_N 24 | ), shape=shape).toarray()).float() 25 | 26 | def to_sparse_format_conv2d(dense_weight): 27 | if isinstance(dense_weight, torch.nn.Parameter): 28 | dense_weight = dense_weight.data 29 | OC, IC, K, _ = dense_weight.shape 30 | nnz = torch.sum(dense_weight != 0).item() 31 | W_val = torch.zeros(nnz).float() 32 | W_OC = torch.zeros(OC + 1).int() 33 | W_IC = torch.zeros((IC + 1) * OC).type(torch.short) 34 | W_X = torch.zeros(nnz).type(torch.uint8) 35 | W_Y = torch.zeros(nnz).type(torch.uint8) 36 | sppb.sparsify_conv2d(OC, IC, K, dense_weight, W_OC, W_IC, W_X, W_Y, W_val) 37 | return W_val, (W_OC, W_IC, W_X, W_Y) 38 | 39 | def from_sparse_format_conv2d(W_val, W_idx, shape): 40 | if isinstance(W_val, torch.nn.Parameter): 41 | W_val = W_val.data 42 | W_OC, W_IC, W_X, W_Y = W_idx 43 | OC, IC, K, _ = shape 44 | dense_weight = torch.zeros(*shape) 45 | sppb.densify_conv2d(OC, IC, K, dense_weight, W_OC, W_IC, W_X, W_Y, W_val) 46 | return dense_weight -------------------------------------------------------------------------------- /sparseprop/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | from sparseprop.modules.linear import SparseLinear 5 | from sparseprop.modules.conv2d import SparseConv2d 6 | from sparseprop.modules import sparsify_if_faster 7 | 8 | @torch.no_grad() 9 | def sparsity(module): 10 | if hasattr(module, 'weight'): 11 | return 1. - torch.mean((module.weight != 0).float()).item() 12 | else: 13 | return 0. 
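
The CSR helpers above are thin wrappers around `scipy.sparse.csr_matrix`. A round-trip sketch; the weight shape and random mask are illustrative.

```python
import torch
from sparseprop.modules.utils import to_csr_2d, from_csr_2d

W = torch.randn(64, 128)
W[torch.rand_like(W) < 0.9] = 0.0

W_val, (idx_N, idx_M) = to_csr_2d(W)
assert idx_N.numel() == 64 + 1                   # CSR row pointer over the N rows
assert torch.allclose(from_csr_2d(W_val, (idx_N, idx_M), shape=(64, 128)), W)
```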
14 | 15 | def swap_module(network, module_name, new_module): 16 | name_parts = module_name.split('.') 17 | parent = network 18 | for part in name_parts[:-1]: 19 | if part.isdigit(): 20 | parent = parent[int(part)] 21 | else: 22 | parent = getattr(parent, part) 23 | 24 | last_part = name_parts[-1] 25 | if last_part.isdigit(): 26 | parent[int(last_part)] = new_module 27 | else: 28 | setattr(parent, last_part, new_module) 29 | 30 | class ShapeHook: 31 | def __init__(self): 32 | self.inshape = None 33 | self.outshape = None 34 | 35 | def __call__(self, module, inp, out): 36 | if isinstance(inp, tuple): 37 | inp = inp[0] 38 | if isinstance(out, tuple): 39 | out = out[0] 40 | 41 | self.inshape = inp.shape[1:] 42 | self.outshape = out.shape[1:] 43 | 44 | def generate_intermediate_shapes(network, input_shape): 45 | hooks = {} 46 | handles = [] 47 | for name, module in network.named_modules(): 48 | if any([isinstance(module, c) for c in [torch.nn.Linear, torch.nn.Conv2d]]): 49 | hook = ShapeHook() 50 | handles.append(module.register_forward_hook(hook)) 51 | hooks[name] = hook 52 | 53 | B = 1 54 | training_mode = network.training 55 | network.eval() 56 | with torch.no_grad(): 57 | network(torch.randn(B, *input_shape)) 58 | network.train(training_mode) 59 | 60 | inshapes = {name: hook.inshape for name, hook in hooks.items()} 61 | outshapes = {name: hook.outshape for name, hook in hooks.items()} 62 | 63 | for handle in handles: 64 | handle.remove() 65 | 66 | return inshapes, outshapes 67 | 68 | def swap_modules_with_sparse(network, input_shape, inplace=False, skip_modules=None, verbose=False): 69 | # e.g., shapes_tag='resnet18', input_shape=(B, IC, M, N), skip_modules='input,conv1,conv2' 70 | 71 | if not inplace: 72 | network = deepcopy(network) 73 | if skip_modules is not None: 74 | skip_modules = skip_modules.split(',') 75 | 76 | B = input_shape[0] 77 | input_shape = input_shape[1:] 78 | inshapes, _ = generate_intermediate_shapes(network, input_shape) 79 | 80 | for name, module in network.named_modules(): 81 | is_conv = isinstance(module, torch.nn.Conv2d) 82 | is_linear = isinstance(module, torch.nn.Linear) 83 | if not is_conv and not is_linear: 84 | continue 85 | 86 | found = False 87 | if skip_modules is not None: 88 | for sm in skip_modules: 89 | if name == sm.strip(): 90 | found = True 91 | break 92 | 93 | if verbose: 94 | print('-' * 30) 95 | 96 | if found: 97 | print(f'Skipped {name}.') 98 | continue 99 | 100 | sp = sparsity(module) 101 | new_module = None 102 | if sp > .8: 103 | new_module = sparsify_if_faster( 104 | module, 105 | (B, *inshapes[name]), 106 | verbose=verbose 107 | ) 108 | 109 | if new_module is not None and new_module != module: 110 | swap_module(network, name, new_module) 111 | print(f'module {name} replaced with {str(new_module)}') 112 | else: 113 | print(f'keeping the module {name} dense...') 114 | 115 | return network 116 | 117 | def error(pred, target): 118 | e = torch.mean((pred - target) ** 2) / torch.norm(target) ** 2 119 | return e.item() --------------------------------------------------------------------------------
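
An end-to-end usage sketch along the lines of examples/finetune_resnet18_imagenette.py: load a pruned network, let `swap_modules_with_sparse` replace each Conv2d/Linear that is more than 80% zero with its sparse counterpart whenever that benchmarks faster, then train as usual. The torchvision model and the elided checkpoint loading here are assumptions for illustration, not repository code.

```python
import torch
from torchvision.models import resnet18
from sparseprop.utils import swap_modules_with_sparse

model = resnet18(num_classes=10)
# ... load a pruned (e.g. AC/DC) state_dict into `model` here ...

sparse_model = swap_modules_with_sparse(
    model,
    input_shape=(64, 3, 224, 224),   # (B, C, H, W) used when benchmarking candidates
    inplace=False,
    verbose=True,
)
```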