├── Figs ├── framework.png └── res.png ├── LICENSE ├── README.md └── SViTE ├── LICENSE ├── __pycache__ ├── models.cpython-37.pyc ├── vision2.cpython-37.pyc ├── vision3.cpython-37.pyc ├── vision_transformer.cpython-37.pyc └── vision_transformer_data.cpython-37.pyc ├── backup ├── 01_run_tiny_baseline.sh ├── 02_run_tiny_dst.sh ├── 032801_cifar_dst.log ├── dst_utils_core.py ├── sparselearning │ ├── Grasp.py │ ├── Srelu.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── core.cpython-36.pyc │ │ ├── models.cpython-35.pyc │ │ ├── models.cpython-36.pyc │ │ ├── snip.cpython-35.pyc │ │ ├── snip.cpython-36.pyc │ │ ├── sparse_sgd.cpython-36.pyc │ │ ├── utils.cpython-35.pyc │ │ └── utils.cpython-36.pyc │ ├── copy │ │ ├── test.py │ │ └── utils.py │ ├── core.py │ ├── funcs.py │ ├── models.py │ ├── resnet.py │ ├── snip.py │ ├── sparse_sgd.py │ └── utils.py ├── test_mask.py └── vmsh │ ├── 0404 │ ├── 01_vm1_run_tiny_dst.sh │ ├── 03_vm3_run_tiny_dst.sh │ └── 09_vm9_run_tiny_dst.sh │ └── 0405 │ ├── 01_vm1_run_tiny_dst.sh │ ├── 03_vm3_run_tiny_dst.sh │ └── 09_vm9_run_tiny_dst.sh ├── cmd ├── utsh │ └── 01_run_tiny_dst_structure.sh └── vm │ ├── 0409 │ ├── vm1.sh │ ├── vm3.sh │ └── vm9.sh │ ├── 0414 │ ├── vm1.sh │ ├── vm3.sh │ └── vm9.sh │ ├── 0416 │ ├── vm1.sh │ ├── vm2.sh │ ├── vm3.sh │ └── vm9.sh │ ├── 0420 │ └── vm2.sh │ ├── 0422 │ ├── vm1.sh │ └── vm3.sh │ ├── 0424 │ └── vm9.sh │ ├── 0426 │ ├── vm1.sh │ └── vm3.sh │ ├── 0428 │ └── deit_tiny_structure.sh │ ├── 0430 │ └── deit_base_structure.sh │ ├── 0502 │ ├── vm14.sh │ ├── vm2.sh │ └── vm3.sh │ ├── 0506 │ ├── structure_small.sh │ └── vm3.sh │ ├── 0507 │ ├── line67.sh │ ├── line69.sh │ └── vm1.sh │ ├── 0508 │ ├── line71.sh │ ├── line72.sh │ ├── line73.sh │ ├── line74.sh │ ├── line75.sh │ ├── line77.sh │ ├── line78.sh │ └── line79.sh │ ├── 0509 │ └── gumbel.sh │ ├── 0516 │ └── line70.sh │ ├── 0517 │ ├── line67.sh │ ├── line67_new_95.sh │ ├── line68.sh │ ├── line69.sh │ ├── line71.sh │ ├── line72.sh │ ├── line73.sh │ └── resume.sh │ ├── 0519 │ ├── layer_wise_token_uns_small.sh │ ├── line72_95_s.sh │ ├── line78.sh │ └── line79.sh │ ├── inference.sh │ └── inference_data.sh ├── datasets.py ├── dst_utils └── core.py ├── engine.py ├── hubconf.py ├── inference.py ├── losses.py ├── main.py ├── models.py ├── requirements.txt ├── run_with_submitit.py ├── samplers.py ├── tox.ini ├── utils.py ├── vision2.py ├── vision3.py ├── vision_gumbel.py ├── vision_gumbel_structure.py ├── vision_transformer.py └── vision_transformer_data.py /Figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/Figs/framework.png -------------------------------------------------------------------------------- /Figs/res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/Figs/res.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 VITA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chasing Sparsity in Vision Transformers: An End-to-End Exploration
2 | 
3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
4 | 
5 | Code for [NeurIPS'21] [Chasing Sparsity in Vision Transformers: An End-to-End Exploration](https://arxiv.org/pdf/2106.04533.pdf).
6 | 
7 | Tianlong Chen, Yu Cheng, Zhe Gan, Lu Yuan, Lei Zhang, Zhangyang Wang
8 | 
9 | 
10 | 
11 | ## Overall Results
12 | 
13 | 
14 | 
15 | Extensive results on ImageNet with diverse ViT backbones validate the effectiveness of our proposals, which significantly reduce computational cost while leaving generalization almost unimpaired. Perhaps most surprisingly, we find that the proposed sparse (co-)training can even *improve the ViT accuracy* rather than compromising it, making sparsity a tantalizing “free lunch”. For example, our sparsified DeiT-Small at (5%, 50%) sparsity for (data, architecture) improves top-1 accuracy by 0.28% while saving 49.32% of FLOPs and 4.40% of running time.
16 | 
17 | 
18 | 
19 | ## Proposed Framework of SViTE
20 | 
21 | ![](./Figs/framework.png)
22 | 
23 | 
24 | 
25 | ## Implementations of SViTE
26 | 
27 | ### Set Environment
28 | 
29 | ```shell
30 | conda create -n vit python=3.6
31 | 
32 | pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
33 | 
34 | pip install tqdm scipy timm
35 | 
36 | git clone https://github.com/NVIDIA/apex
37 | 
38 | cd apex
39 | 
40 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
41 | 
42 | pip install -v --disable-pip-version-check --no-cache-dir ./
43 | ```
44 | 
45 | ### Cmd
46 | 
47 | Commands for unstructured sparsity, i.e., SViTE.
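These launch scripts drive dynamic sparse training through the `--sparse_init fixed_ERK`, `--density`, `--update_frequency`, `--growth gradient`, and `--death magnitude` flags: the model starts from an ERK-initialized sparse mask at the target density, and every `update_frequency` iterations the smallest-magnitude active weights are pruned while an equal number of connections are regrown where the gradient magnitude is largest. The snippet below is only a minimal sketch of that prune-and-regrow step for a single weight tensor; the helper name is made up for illustration and is not the repository's actual masking code in `dst_utils/core.py`.

```python
# Hypothetical sketch of one mask update in dynamic sparse training
# (magnitude "death" + gradient "growth"). Names are illustrative only.
import torch

@torch.no_grad()
def prune_and_regrow(weight, mask, grad, prune_rate=0.5):
    """Drop the smallest-magnitude active weights, then regrow the same
    number of inactive connections where the dense gradient is largest."""
    n_prune = int(prune_rate * mask.sum().item())
    if n_prune == 0:
        return weight, mask

    # Death: among active weights, zero out those with the smallest magnitude.
    active = torch.where(mask.bool(), weight.abs(),
                         torch.full_like(weight, float('inf')))
    drop_idx = torch.topk(active.view(-1), n_prune, largest=False).indices
    mask.view(-1)[drop_idx] = 0.0

    # Growth: among inactive positions, re-enable those with the largest
    # gradient magnitude; regrown weights start from zero.
    inactive = torch.where(mask.bool(),
                           torch.full_like(grad, -float('inf')), grad.abs())
    grow_idx = torch.topk(inactive.view(-1), n_prune, largest=True).indices
    mask.view(-1)[grow_idx] = 1.0
    weight.view(-1)[grow_idx] = 0.0

    weight.mul_(mask)  # keep pruned weights at exactly zero
    return weight, mask

# Toy usage: a 4x4 layer kept at roughly 50% density, one prune-and-regrow step.
w = torch.randn(4, 4)
m = (torch.rand(4, 4) < 0.5).float()
g = torch.randn(4, 4)  # stand-in for the dense gradient of the loss w.r.t. w
w, m = prune_and_regrow(w, m, g, prune_rate=0.25)
```

In the actual training loop this kind of update is applied to every sparse layer's mask at the configured `--update_frequency`, so the overall parameter density stays fixed while the connectivity pattern evolves.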
48 | 
49 | - SViTE-Small
50 | 
51 | ```shell
52 | bash cmd/vm/0426/vm1.sh 0,1,2,3,4,5,6,7
53 | ```
54 | 
55 | Details
56 | 
57 | ```shell
58 | CUDA_VISIBLE_DEVICES=$1 \
59 | python -m torch.distributed.launch \
60 | --nproc_per_node=8 \
61 | --use_env main.py \
62 | --model deit_small_patch16_224 \
63 | --epochs 600 \
64 | --batch-size 64 \
65 | --data-path ../../imagenet \
66 | --output_dir ./small_dst_uns_0426_vm1 \
67 | --dist_url tcp://127.0.0.1:23305 \
68 | --sparse_init fixed_ERK \
69 | --density 0.4 \
70 | --update_frequency 15000 \
71 | --growth gradient \
72 | --death magnitude \
73 | --redistribution none
74 | ```
75 | 
76 | - SViTE-Base
77 | 
78 | ```shell
79 | bash cmd/vm/0426/vm3.sh 0,1,2,3,4,5,6,7
80 | ```
81 | 
82 | Details
83 | 
84 | ```shell
85 | CUDA_VISIBLE_DEVICES=$1 \
86 | python -m torch.distributed.launch \
87 | --nproc_per_node=8 \
88 | --use_env main.py \
89 | --model deit_base_patch16_224 \
90 | --epochs 600 \
91 | --batch-size 128 \
92 | --data-path ../../imagenet \
93 | --output_dir ./base_dst_uns_0426_vm3 \
94 | --dist_url tcp://127.0.0.1:23305 \
95 | --sparse_init fixed_ERK \
96 | --density 0.4 \
97 | --update_frequency 7000 \
98 | --growth gradient \
99 | --death magnitude \
100 | --redistribution none
101 | ```
102 | 
103 | **Remark.** More commands can be found under the "cmd" folder.
104 | 
105 | ## Citation
106 | 
107 | ```
108 | @misc{chen2021chasing,
109 | title={Chasing Sparsity in Vision Transformers: An End-to-End Exploration},
110 | author={Tianlong Chen and Yu Cheng and Zhe Gan and Lu Yuan and Lei Zhang and Zhangyang Wang},
111 | year={2021},
112 | eprint={2106.04533},
113 | archivePrefix={arXiv},
114 | primaryClass={cs.CV}
115 | }
116 | ```
117 | 
118 | 
119 | 
120 | ## Acknowledgement of Related Repos
121 | 
122 | ViT : https://github.com/jeonsworld/ViT-pytorch
123 | 
124 | ViT : https://github.com/google-research/vision_transformer
125 | 
126 | RigL: https://github.com/google-research/rigl
127 | 
128 | DeiT: https://github.com/facebookresearch/deit
129 | 
--------------------------------------------------------------------------------
/SViTE/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /SViTE/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /SViTE/__pycache__/vision2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/__pycache__/vision2.cpython-37.pyc -------------------------------------------------------------------------------- /SViTE/__pycache__/vision3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/__pycache__/vision3.cpython-37.pyc -------------------------------------------------------------------------------- /SViTE/__pycache__/vision_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/__pycache__/vision_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /SViTE/__pycache__/vision_transformer_data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/__pycache__/vision_transformer_data.cpython-37.pyc -------------------------------------------------------------------------------- /SViTE/backup/01_run_tiny_baseline.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=2 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --batch-size 256 \ 7 | --data-path ../../../imagenet \ 8 | --output_dir ./tiny_baseline \ 9 | --dist_url tcp://127.0.0.1:3333 -------------------------------------------------------------------------------- /SViTE/backup/02_run_tiny_dst.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --batch-size 64 \ 7 | --data-path ../../imagenet \ 8 | --output_dir ./tiny_dst \ 9 | --dist_url tcp://127.0.0.1:2457 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.05 \ 12 | --update_frequency 1000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/Grasp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | import copy 8 | import types 9 | 10 | 11 | def GraSP_fetch_data(dataloader, num_classes, samples_per_class): 12 | datas = [[] for _ in range(num_classes)] 13 | labels = [[] for _ in range(num_classes)] 14 | mark = dict() 15 | dataloader_iter = iter(dataloader) 16 | while True: 17 | inputs, targets = next(dataloader_iter) 18 | for idx in 
range(inputs.shape[0]): 19 | x, y = inputs[idx:idx+1], targets[idx:idx+1] 20 | category = y.item() 21 | if len(datas[category]) == samples_per_class: 22 | mark[category] = True 23 | continue 24 | datas[category].append(x) 25 | labels[category].append(y) 26 | if len(mark) == num_classes: 27 | break 28 | 29 | X, y = torch.cat([torch.cat(_, 0) for _ in datas]), torch.cat([torch.cat(_) for _ in labels]).view(-1) 30 | return X, y 31 | 32 | 33 | def count_total_parameters(net): 34 | total = 0 35 | for m in net.modules(): 36 | if isinstance(m, (nn.Linear, nn.Conv2d)): 37 | total += m.weight.numel() 38 | return total 39 | 40 | 41 | def count_fc_parameters(net): 42 | total = 0 43 | for m in net.modules(): 44 | if isinstance(m, (nn.Linear)): 45 | total += m.weight.numel() 46 | return total 47 | 48 | 49 | def GraSP(net, ratio, train_dataloader, device, num_classes=10, samples_per_class=25, num_iters=1, T=200, reinit=True): 50 | eps = 1e-10 51 | keep_ratio = ratio 52 | old_net = net 53 | 54 | net = copy.deepcopy(net) # .eval() 55 | net.zero_grad() 56 | 57 | weights = [] 58 | total_parameters = count_total_parameters(net) 59 | fc_parameters = count_fc_parameters(net) 60 | 61 | # rescale_weights(net) 62 | for layer in net.modules(): 63 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 64 | if isinstance(layer, nn.Linear) and reinit: 65 | nn.init.xavier_normal(layer.weight) 66 | weights.append(layer.weight) 67 | 68 | inputs_one = [] 69 | targets_one = [] 70 | 71 | grad_w = None 72 | for w in weights: 73 | w.requires_grad_(True) 74 | 75 | print_once = False 76 | for it in range(num_iters): 77 | print("(1): Iterations %d/%d." % (it, num_iters)) 78 | inputs, targets = GraSP_fetch_data(train_dataloader, num_classes, samples_per_class) 79 | N = inputs.shape[0] 80 | din = copy.deepcopy(inputs) 81 | dtarget = copy.deepcopy(targets) 82 | inputs_one.append(din[:N//2]) 83 | targets_one.append(dtarget[:N//2]) 84 | inputs_one.append(din[N // 2:]) 85 | targets_one.append(dtarget[N // 2:]) 86 | inputs = inputs.to(device) 87 | targets = targets.to(device) 88 | 89 | outputs = net.forward(inputs[:N//2])/T 90 | if print_once: 91 | # import pdb; pdb.set_trace() 92 | x = F.softmax(outputs) 93 | print(x) 94 | print(x.max(), x.min()) 95 | print_once = False 96 | loss = F.cross_entropy(outputs, targets[:N//2]) 97 | # ===== debug ================ 98 | grad_w_p = autograd.grad(loss, weights) 99 | if grad_w is None: 100 | grad_w = list(grad_w_p) 101 | else: 102 | for idx in range(len(grad_w)): 103 | grad_w[idx] += grad_w_p[idx] 104 | 105 | outputs = net.forward(inputs[N // 2:])/T 106 | loss = F.cross_entropy(outputs, targets[N // 2:]) 107 | grad_w_p = autograd.grad(loss, weights, create_graph=False) 108 | if grad_w is None: 109 | grad_w = list(grad_w_p) 110 | else: 111 | for idx in range(len(grad_w)): 112 | grad_w[idx] += grad_w_p[idx] 113 | 114 | ret_inputs = [] 115 | ret_targets = [] 116 | 117 | for it in range(len(inputs_one)): 118 | print("(2): Iterations %d/%d." 
% (it, num_iters)) 119 | inputs = inputs_one.pop(0).to(device) 120 | targets = targets_one.pop(0).to(device) 121 | ret_inputs.append(inputs) 122 | ret_targets.append(targets) 123 | outputs = net.forward(inputs)/T 124 | loss = F.cross_entropy(outputs, targets) 125 | # ===== debug ============== 126 | 127 | grad_f = autograd.grad(loss, weights, create_graph=True) 128 | z = 0 129 | count = 0 130 | for layer in net.modules(): 131 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 132 | z += (grad_w[count].data * grad_f[count]).sum() 133 | count += 1 134 | z.backward() 135 | 136 | grads = dict() 137 | old_modules = list(old_net.modules()) 138 | for idx, layer in enumerate(net.modules()): 139 | if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): 140 | grads[old_modules[idx]] = -layer.weight.data * layer.weight.grad # -theta_q Hg 141 | 142 | # Gather all scores in a single vector and normalise 143 | all_scores = torch.cat([torch.flatten(x) for x in grads.values()]) 144 | norm_factor = torch.abs(torch.sum(all_scores)) + eps 145 | print("** norm factor:", norm_factor) 146 | all_scores.div_(norm_factor) 147 | 148 | num_params_to_rm = int(len(all_scores) * (1-keep_ratio)) 149 | threshold, _ = torch.topk(all_scores, num_params_to_rm, sorted=True) 150 | # import pdb; pdb.set_trace() 151 | acceptable_score = threshold[-1] 152 | print('** accept: ', acceptable_score) 153 | keep_masks = [] 154 | for m, g in grads.items(): 155 | keep_masks.append(((g / norm_factor) <= acceptable_score).float()) 156 | 157 | # print(torch.sum(torch.cat([torch.flatten(x == 1) for x in keep_masks.values()]))) 158 | 159 | return keep_masks -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/Srelu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script defined the SReLU (S-shaped Rectified Linear Activation Unit): 3 | .. math:: 4 | h(x_i) = \\left\\{\\begin{matrix} t_i^r + a_i^r(x_i - t_i^r), x_i \\geq t_i^r \\\\ x_i, t_i^r > x_i > t_i^l\\\\ t_i^l + a_i^l(x_i - t_i^l), x_i \\leq t_i^l \\\\ \\end{matrix}\\right. 5 | See SReLU paper: 6 | https://arxiv.org/pdf/1512.07030.pdf 7 | """ 8 | 9 | # import pytorch 10 | import torch 11 | from torch import nn 12 | from torch.nn.parameter import Parameter 13 | 14 | 15 | class SReLU(nn.Module): 16 | """ 17 | SReLU (S-shaped Rectified Linear Activation Unit): a combination of three linear functions, which perform mapping R → R with the following formulation: 18 | .. math:: 19 | h(x_i) = \\left\\{\\begin{matrix} t_i^r + a_i^r(x_i - t_i^r), x_i \\geq t_i^r \\\\ x_i, t_i^r > x_i > t_i^l\\\\ t_i^l + a_i^l(x_i - t_i^l), x_i \\leq t_i^l \\\\ \\end{matrix}\\right. 20 | with 4 trainable parameters. 21 | Shape: 22 | - Input: (N, *) where * means, any number of additional 23 | dimensions 24 | - Output: (N, *), same shape as the input 25 | Parameters: 26 | .. math:: \\{t_i^r, a_i^r, t_i^l, a_i^l\\} 27 | 4 trainable parameters, which model an individual SReLU activation unit. The subscript i indicates that we allow SReLU to vary in different channels. Parameters can be initialized manually or randomly. 28 | References: 29 | - See SReLU paper: 30 | https://arxiv.org/pdf/1512.07030.pdf 31 | Examples: 32 | >>> srelu_activation = srelu((2,2)) 33 | >>> t = torch.randn((2,2), dtype=torch.float, requires_grad = True) 34 | >>> output = srelu_activation(t) 35 | """ 36 | 37 | def __init__(self, in_features, parameters=None): 38 | """ 39 | Initialization. 
40 | INPUT: 41 | - in_features: shape of the input 42 | - parameters: (tr, tl, ar, al) parameters for manual initialization, default value is None. If None is passed, parameters are initialized randomly. 43 | """ 44 | super(SReLU, self).__init__() 45 | self.in_features = in_features 46 | 47 | if parameters is None: 48 | self.tr = Parameter( 49 | torch.randn(in_features, dtype=torch.float, requires_grad=True) 50 | ) 51 | self.tl = Parameter( 52 | torch.randn(in_features, dtype=torch.float, requires_grad=True) 53 | ) 54 | self.ar = Parameter( 55 | torch.randn(in_features, dtype=torch.float, requires_grad=True) 56 | ) 57 | self.al = Parameter( 58 | torch.randn(in_features, dtype=torch.float, requires_grad=True) 59 | ) 60 | else: 61 | self.tr, self.tl, self.ar, self.al = parameters 62 | 63 | def forward(self, x): 64 | """ 65 | Forward pass of the function 66 | """ 67 | return ( 68 | (x >= self.tr).float() * (self.tr + self.ar * (x + self.tr)) 69 | + (x < self.tr).float() * (x > self.tl).float() * x 70 | + (x <= self.tl).float() * (self.tl + self.al * (x + self.tl)) 71 | ) -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.getLogger(__name__).addHandler(logging.NullHandler()) 3 | -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/core.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/core.cpython-36.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/models.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/models.cpython-35.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/snip.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/snip.cpython-35.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/snip.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/snip.cpython-36.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/sparse_sgd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/sparse_sgd.cpython-36.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/SViTE/7dc89fd8fa5f86e797620f00cce6f11e2b73765d/SViTE/backup/sparselearning/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/copy/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | grad = torch.tensor([[-0.1,-0.2,-0.3],[0,0,0],[0.1,0.2,0.3]]) 4 | new_mask = torch.tensor([[0,0,0],[0,0,0],[0,0,0]]) 5 | print(new_mask) 6 | total_regrowth = 6 7 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True) 8 | new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0 9 | print(new_mask) 10 | -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/copy/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from torchvision import datasets, transforms 7 | 8 | class DatasetSplitter(torch.utils.data.Dataset): 9 | """This splitter makes sure that we always use the same training/validation split""" 10 | def __init__(self,parent_dataset,split_start=-1,split_end= -1): 11 | split_start = split_start if split_start != -1 else 0 12 | split_end = split_end if split_end != -1 else len(parent_dataset) 13 | assert split_start <= len(parent_dataset) - 1 and split_end <= len(parent_dataset) and split_start < split_end , "invalid dataset split" 14 | 15 | self.parent_dataset = parent_dataset 16 | self.split_start = split_start 17 | self.split_end = split_end 18 | 19 | def __len__(self): 20 | return self.split_end - self.split_start 21 | 22 | 23 | def __getitem__(self,index): 24 | assert index < len(self),"index out of bounds in split_datset" 25 | return self.parent_dataset[index + self.split_start] 26 | 27 | def get_cifar10_dataloaders(args, validation_split=0.0, max_threads=10): 28 | """Creates augmented train, validation, and test data 
loaders.""" 29 | 30 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), 31 | (0.2023, 0.1994, 0.2010)) 32 | 33 | train_transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), 36 | (4,4,4,4),mode='reflect').squeeze()), 37 | transforms.ToPILImage(), 38 | transforms.RandomCrop(32), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | print('no normalize!') 42 | ]) 43 | 44 | test_transform = transforms.Compose([ 45 | transforms.ToTensor(), 46 | normalize 47 | ]) 48 | 49 | full_dataset = datasets.CIFAR10('_dataset', True, train_transform, download=True) 50 | test_dataset = datasets.CIFAR10('_dataset', False, test_transform, download=False) 51 | 52 | 53 | # we need at least two threads 54 | max_threads = 2 if max_threads < 2 else max_threads 55 | if max_threads >= 6: 56 | val_threads = 2 57 | train_threads = max_threads - val_threads 58 | else: 59 | val_threads = 1 60 | train_threads = max_threads - 1 61 | 62 | 63 | valid_loader = None 64 | if validation_split > 0.0: 65 | split = int(np.floor((1.0-validation_split) * len(full_dataset))) 66 | train_dataset = DatasetSplitter(full_dataset,split_end=split) 67 | val_dataset = DatasetSplitter(full_dataset,split_start=split) 68 | train_loader = torch.utils.data.DataLoader( 69 | train_dataset, 70 | args.batch_size, 71 | num_workers=train_threads, 72 | pin_memory=True, shuffle=True) 73 | valid_loader = torch.utils.data.DataLoader( 74 | val_dataset, 75 | args.test_batch_size, 76 | num_workers=val_threads, 77 | pin_memory=True) 78 | else: 79 | train_loader = torch.utils.data.DataLoader( 80 | full_dataset, 81 | args.batch_size, 82 | num_workers=8, 83 | pin_memory=True, shuffle=True) 84 | 85 | print('Train loader length', len(train_loader)) 86 | 87 | test_loader = torch.utils.data.DataLoader( 88 | test_dataset, 89 | args.test_batch_size, 90 | shuffle=False, 91 | num_workers=1, 92 | pin_memory=True) 93 | 94 | train_loader_jac = torch.utils.data.DataLoader( 95 | full_dataset, 96 | batch_size=args.batch_size_jac, shuffle=False, num_workers=1, 97 | pin_memory=True) 98 | 99 | return train_loader, valid_loader, test_loader, train_loader_jac 100 | 101 | 102 | def get_mnist_dataloaders(args, validation_split=0.0): 103 | """Creates augmented train, validation, and test data loaders.""" 104 | normalize = transforms.Normalize((0.1307,), (0.3081,)) 105 | transform = transform=transforms.Compose([transforms.ToTensor(),normalize]) 106 | 107 | full_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform) 108 | test_dataset = datasets.MNIST('../data', train=False, transform=transform) 109 | 110 | dataset_size = len(full_dataset) 111 | indices = list(range(dataset_size)) 112 | split = int(np.floor(validation_split * dataset_size)) 113 | 114 | valid_loader = None 115 | if validation_split > 0.0: 116 | split = int(np.floor((1.0-validation_split) * len(full_dataset))) 117 | train_dataset = DatasetSplitter(full_dataset,split_end=split) 118 | val_dataset = DatasetSplitter(full_dataset,split_start=split) 119 | train_loader = torch.utils.data.DataLoader( 120 | train_dataset, 121 | args.batch_size, 122 | num_workers=8, 123 | pin_memory=True, shuffle=True) 124 | valid_loader = torch.utils.data.DataLoader( 125 | val_dataset, 126 | args.test_batch_size, 127 | num_workers=2, 128 | pin_memory=True) 129 | else: 130 | train_loader = torch.utils.data.DataLoader( 131 | full_dataset, 132 | args.batch_size, 133 | num_workers=8, 134 | pin_memory=True, shuffle=True) 135 | 136 | 
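    # Note: when validation_split == 0.0 this function (like get_cifar10_dataloaders
    # above) returns valid_loader as None, so callers must handle a missing
    # validation set before iterating over it.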
print('Train loader length', len(train_loader)) 137 | 138 | test_loader = torch.utils.data.DataLoader( 139 | test_dataset, 140 | args.test_batch_size, 141 | shuffle=False, 142 | num_workers=1, 143 | pin_memory=True) 144 | 145 | return train_loader, valid_loader, test_loader 146 | 147 | 148 | def plot_class_feature_histograms(args, model, device, test_loader, optimizer): 149 | if not os.path.exists('./results'): os.mkdir('./results') 150 | model.eval() 151 | agg = {} 152 | num_classes = 10 153 | feat_id = 0 154 | sparse = not args.dense 155 | model_name = 'alexnet' 156 | #model_name = 'vgg' 157 | #model_name = 'wrn' 158 | 159 | 160 | densities = None 161 | for batch_idx, (data, target) in enumerate(test_loader): 162 | if batch_idx % 100 == 0: print(batch_idx,'/', len(test_loader)) 163 | with torch.no_grad(): 164 | #if batch_idx == 10: break 165 | data, target = data.to(device), target.to(device) 166 | for cls in range(num_classes): 167 | #print('=='*50) 168 | #print('CLASS {0}'.format(cls)) 169 | model.t = target 170 | sub_data = data[target == cls] 171 | 172 | output = model(sub_data) 173 | 174 | feats = model.feats 175 | if densities is None: 176 | densities = [] 177 | densities += model.densities 178 | 179 | if len(agg) == 0: 180 | for feat_id, feat in enumerate(feats): 181 | agg[feat_id] = [] 182 | #print(feat.shape) 183 | for i in range(feat.shape[1]): 184 | agg[feat_id].append(np.zeros((num_classes,))) 185 | 186 | for feat_id, feat in enumerate(feats): 187 | map_contributions = torch.abs(feat).sum([0, 2, 3]) 188 | for map_id in range(map_contributions.shape[0]): 189 | #print(feat_id, map_id, cls) 190 | #print(len(agg), len(agg[feat_id]), len(agg[feat_id][map_id]), len(feats)) 191 | agg[feat_id][map_id][cls] += map_contributions[map_id].item() 192 | 193 | del model.feats[:] 194 | del model.densities[:] 195 | model.feats = [] 196 | model.densities = [] 197 | 198 | if sparse: 199 | np.save('./results/{0}_sparse_density_data'.format(model_name), densities) 200 | 201 | for feat_id, map_data in agg.items(): 202 | data = np.array(map_data) 203 | #print(feat_id, data) 204 | full_contribution = data.sum() 205 | #print(full_contribution, data) 206 | contribution_per_channel = ((1.0/full_contribution)*data.sum(1)) 207 | #print('pre', data.shape[0]) 208 | channels = data.shape[0] 209 | #data = data[contribution_per_channel > 0.001] 210 | 211 | channel_density = np.cumsum(np.sort(contribution_per_channel)) 212 | print(channel_density) 213 | idx = np.argsort(contribution_per_channel) 214 | 215 | threshold_idx = np.searchsorted(channel_density, 0.05) 216 | print(data.shape, 'pre') 217 | data = data[idx[threshold_idx:]] 218 | print(data.shape, 'post') 219 | 220 | #perc = np.percentile(contribution_per_channel[contribution_per_channel > 0.0], 10) 221 | #print(contribution_per_channel, perc, feat_id) 222 | #data = data[contribution_per_channel > perc] 223 | #print(contribution_per_channel[contribution_per_channel < perc].sum()) 224 | #print('post', data.shape[0]) 225 | normed_data = np.max(data/np.sum(data,1).reshape(-1, 1), 1) 226 | #normed_data = (data/np.sum(data,1).reshape(-1, 1) > 0.2).sum(1) 227 | #counts, bins = np.histogram(normed_data, bins=4, range=(0, 4)) 228 | np.save('./results/{2}_{1}_feat_data_layer_{0}'.format(feat_id, 'sparse' if sparse else 'dense', model_name), normed_data) 229 | #plt.ylim(0, channels/2.0) 230 | ##plt.hist(normed_data, bins=range(0, 5)) 231 | #plt.hist(normed_data, bins=[(i+20)/float(200) for i in range(180)]) 232 | #plt.xlim(0.1, 0.5) 233 | #if sparse: 234 | # 
plt.title("Sparse: Conv2D layer {0}".format(feat_id)) 235 | # plt.savefig('./output/feat_histo/layer_{0}_sp.png'.format(feat_id)) 236 | #else: 237 | # plt.title("Dense: Conv2D layer {0}".format(feat_id)) 238 | # plt.savefig('./output/feat_histo/layer_{0}_d.png'.format(feat_id)) 239 | #plt.clf() 240 | -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | ''' 4 | REDISTRIBUTION 5 | ''' 6 | 7 | def momentum_redistribution(masking, name, weight, mask): 8 | """Calculates momentum redistribution statistics. 9 | 10 | Args: 11 | masking Masking class with state about current 12 | layers and the entire sparse network. 13 | 14 | name The name of the layer. This can be used to 15 | access layer-specific statistics in the 16 | masking class. 17 | 18 | weight The weight of the respective sparse layer. 19 | This is a torch parameter. 20 | 21 | mask The binary mask. 1s indicated active weights. 22 | 23 | Returns: 24 | Layer Statistic The unnormalized layer statistics 25 | for the layer "name". A higher value indicates 26 | that more pruned parameters are redistributed 27 | to this layer compared to layers with lower value. 28 | The values will be automatically sum-normalized 29 | after this step. 30 | 31 | 32 | The calculation of redistribution statistics is the first 33 | step in this sparse learning library. 34 | """ 35 | grad = masking.get_momentum_for_weight(weight) 36 | mean_magnitude = torch.abs(grad[mask.bool()]).mean().item() 37 | return mean_magnitude 38 | 39 | def magnitude_redistribution(masking, name, weight, mask): 40 | mean_magnitude = torch.abs(weight)[mask.bool()].mean().item() 41 | return mean_magnitude 42 | 43 | def nonzero_redistribution(masking, name, weight, mask): 44 | nonzero = (weight !=0.0).sum().item() 45 | return nonzero 46 | 47 | def no_redistribution(masking, name, weight, mask): 48 | num_params = masking.baseline_nonzero 49 | n = weight.numel() 50 | return n/float(num_params) 51 | 52 | 53 | ''' 54 | PRUNE 55 | ''' 56 | def magnitude_prune(masking, mask, weight, name): 57 | """Prunes the weights with smallest magnitude. 58 | 59 | The pruning functions in this sparse learning library 60 | work by constructing a binary mask variable "mask" 61 | which prevents gradient flow to weights and also 62 | sets the weights to zero where the binary mask is 0. 63 | Thus 1s in the "mask" variable indicate where the sparse 64 | network has active weights. In this function name 65 | and masking can be used to access global statistics 66 | about the specific layer (name) and the sparse network 67 | as a whole. 68 | 69 | Args: 70 | masking Masking class with state about current 71 | layers and the entire sparse network. 72 | 73 | mask The binary mask. 1s indicated active weights. 74 | 75 | weight The weight of the respective sparse layer. 76 | This is a torch parameter. 77 | 78 | name The name of the layer. This can be used to 79 | access layer-specific statistics in the 80 | masking class. 81 | 82 | Returns: 83 | mask Pruned Binary mask where 1s indicated active 84 | weights. 
Can be modified in-place or newly 85 | constructed 86 | 87 | Accessable global statistics: 88 | 89 | Layer statistics: 90 | Non-zero count of layer: 91 | masking.name2nonzeros[name] 92 | Zero count of layer: 93 | masking.name2zeros[name] 94 | Redistribution proportion: 95 | masking.name2variance[name] 96 | Number of items removed through pruning: 97 | masking.name2removed[name] 98 | 99 | Network statistics: 100 | Total number of nonzero parameter in the network: 101 | masking.total_nonzero = 0 102 | Total number of zero-valued parameter in the network: 103 | masking.total_zero = 0 104 | Total number of parameters removed in pruning: 105 | masking.total_removed = 0 106 | """ 107 | num_remove = math.ceil(masking.name2prune_rate[name]*masking.name2nonzeros[name]) 108 | num_zeros = masking.name2zeros[name] 109 | k = math.ceil(num_zeros + num_remove) 110 | if num_remove == 0.0: return weight.data != 0.0 111 | 112 | x, idx = torch.sort(torch.abs(weight.data.view(-1))) 113 | mask.data.view(-1)[idx[:k]] = 0.0 114 | return mask 115 | 116 | def global_magnitude_prune(masking): 117 | prune_rate = 0.0 118 | for name in masking.name2prune_rate: 119 | if name in masking.masks: 120 | prune_rate = masking.name2prune_rate[name] 121 | tokill = math.ceil(prune_rate*masking.baseline_nonzero) 122 | total_removed = 0 123 | prev_removed = 0 124 | while total_removed < tokill*(1.0-masking.tolerance) or (total_removed > tokill*(1.0+masking.tolerance)): 125 | total_removed = 0 126 | for module in masking.modules: 127 | for name, weight in module.named_parameters(): 128 | if name not in masking.masks: continue 129 | remain = (torch.abs(weight.data) > masking.prune_threshold).sum().item() 130 | total_removed += masking.name2nonzeros[name] - remain 131 | 132 | if prev_removed == total_removed: break 133 | prev_removed = total_removed 134 | if total_removed > tokill*(1.0+masking.tolerance): 135 | masking.prune_threshold *= 1.0-masking.increment 136 | masking.increment *= 0.99 137 | elif total_removed < tokill*(1.0-masking.tolerance): 138 | masking.prune_threshold *= 1.0+masking.increment 139 | masking.increment *= 0.99 140 | 141 | for module in masking.modules: 142 | for name, weight in module.named_parameters(): 143 | if name not in masking.masks: continue 144 | masking.masks[name][:] = torch.abs(weight.data) > masking.prune_threshold 145 | 146 | return int(total_removed) 147 | 148 | 149 | def magnitude_and_negativity_prune(masking, mask, weight, name): 150 | num_remove = math.ceil(masking.name2prune_rate[name]*masking.name2nonzeros[name]) 151 | if num_remove == 0.0: return weight.data != 0.0 152 | 153 | num_zeros = masking.name2zeros[name] 154 | k = math.ceil(num_zeros + (num_remove/2.0)) 155 | 156 | # remove all weights which absolute value is smaller than threshold 157 | x, idx = torch.sort(torch.abs(weight.data.view(-1))) 158 | mask.data.view(-1)[idx[:k]] = 0.0 159 | 160 | # remove the most negative weights 161 | x, idx = torch.sort(weight.data.view(-1)) 162 | mask.data.view(-1)[idx[:math.ceil(num_remove/2.0)]] = 0.0 163 | 164 | return mask 165 | 166 | ''' 167 | GROWTH 168 | ''' 169 | 170 | def random_growth(masking, name, new_mask, total_regrowth, weight): 171 | n = (new_mask==0).sum().item() 172 | if n == 0: return new_mask 173 | expeced_growth_probability = (total_regrowth/n) 174 | new_weights = torch.rand(new_mask.shape).cuda() < expeced_growth_probability 175 | return new_mask.bool() | new_weights 176 | 177 | def momentum_growth(masking, name, new_mask, total_regrowth, weight): 178 | """Grows weights in 
places where the momentum is largest. 179 | 180 | Growth function in the sparse learning library work by 181 | changing 0s to 1s in a binary mask which will enable 182 | gradient flow. Weights default value are 0 and it can 183 | be changed in this function. The number of parameters 184 | to be regrown is determined by the total_regrowth 185 | parameter. The masking object in conjunction with the name 186 | of the layer enables the access to further statistics 187 | and objects that allow more flexibility to implement 188 | custom growth functions. 189 | 190 | Args: 191 | masking Masking class with state about current 192 | layers and the entire sparse network. 193 | 194 | name The name of the layer. This can be used to 195 | access layer-specific statistics in the 196 | masking class. 197 | 198 | new_mask The binary mask. 1s indicated active weights. 199 | This binary mask has already been pruned in the 200 | pruning step that preceeds the growth step. 201 | 202 | total_regrowth This variable determines the number of 203 | parameters to regrowtn in this function. 204 | It is automatically determined by the 205 | redistribution function and algorithms 206 | internal to the sparselearning library. 207 | 208 | weight The weight of the respective sparse layer. 209 | This is a torch parameter. 210 | 211 | Returns: 212 | mask Binary mask with newly grown weights. 213 | 1s indicated active weights in the binary mask. 214 | 215 | Access to optimizer: 216 | masking.optimizer 217 | 218 | Access to momentum/Adam update: 219 | masking.get_momentum_for_weight(weight) 220 | 221 | Accessable global statistics: 222 | 223 | Layer statistics: 224 | Non-zero count of layer: 225 | masking.name2nonzeros[name] 226 | Zero count of layer: 227 | masking.name2zeros[name] 228 | Redistribution proportion: 229 | masking.name2variance[name] 230 | Number of items removed through pruning: 231 | masking.name2removed[name] 232 | 233 | Network statistics: 234 | Total number of nonzero parameter in the network: 235 | masking.total_nonzero = 0 236 | Total number of zero-valued parameter in the network: 237 | masking.total_zero = 0 238 | Total number of parameters removed in pruning: 239 | masking.total_removed = 0 240 | """ 241 | grad = masking.get_momentum_for_weight(weight) 242 | if grad.dtype == torch.float16: 243 | grad = grad*(new_mask==0).half() 244 | else: 245 | grad = grad*(new_mask==0).float() 246 | y, idx = torch.sort(torch.abs(grad).flatten(), descending=True) 247 | new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0 248 | 249 | return new_mask 250 | 251 | def momentum_neuron_growth(masking, name, new_mask, total_regrowth, weight): 252 | grad = masking.get_momentum_for_weight(weight) 253 | 254 | M = torch.abs(grad) 255 | if len(M.shape) == 2: sum_dim = [1] 256 | elif len(M.shape) == 4: sum_dim = [1, 2, 3] 257 | 258 | v = M.mean(sum_dim).data 259 | v /= v.sum() 260 | 261 | slots_per_neuron = (new_mask==0).sum(sum_dim) 262 | 263 | M = M*(new_mask==0).float() 264 | for i, fraction in enumerate(v): 265 | neuron_regrowth = math.floor(fraction.item()*total_regrowth) 266 | available = slots_per_neuron[i].item() 267 | 268 | y, idx = torch.sort(M[i].flatten()) 269 | if neuron_regrowth > available: 270 | neuron_regrowth = available 271 | # TODO: Work into more stable growth method 272 | threshold = y[-(neuron_regrowth)].item() 273 | if threshold == 0.0: continue 274 | if neuron_regrowth < 10: continue 275 | new_mask[i] = new_mask[i] | (M[i] > threshold) 276 | 277 | return new_mask 278 | 279 | 280 | def 
global_momentum_growth(masking, total_regrowth): 281 | togrow = total_regrowth 282 | total_grown = 0 283 | last_grown = 0 284 | while total_grown < togrow*(1.0-masking.tolerance) or (total_grown > togrow*(1.0+masking.tolerance)): 285 | total_grown = 0 286 | total_possible = 0 287 | for module in masking.modules: 288 | for name, weight in module.named_parameters(): 289 | if name not in masking.masks: continue 290 | 291 | new_mask = masking.masks[name] 292 | grad = masking.get_momentum_for_weight(weight) 293 | grad = grad*(new_mask==0).float() 294 | possible = (grad !=0.0).sum().item() 295 | total_possible += possible 296 | grown = (torch.abs(grad.data) > masking.growth_threshold).sum().item() 297 | total_grown += grown 298 | if total_grown == last_grown: break 299 | last_grown = total_grown 300 | 301 | 302 | if total_grown > togrow*(1.0+masking.tolerance): 303 | masking.growth_threshold *= 1.02 304 | #masking.growth_increment *= 0.95 305 | elif total_grown < togrow*(1.0-masking.tolerance): 306 | masking.growth_threshold *= 0.98 307 | #masking.growth_increment *= 0.95 308 | 309 | total_new_nonzeros = 0 310 | for module in masking.modules: 311 | for name, weight in module.named_parameters(): 312 | if name not in masking.masks: continue 313 | 314 | new_mask = masking.masks[name] 315 | grad = masking.get_momentum_for_weight(weight) 316 | grad = grad*(new_mask==0).float() 317 | masking.masks[name][:] = (new_mask.bool() | (torch.abs(grad.data) > masking.growth_threshold)).float() 318 | total_new_nonzeros += new_mask.sum().item() 319 | return total_new_nonzeros 320 | 321 | 322 | 323 | 324 | prune_funcs = {} 325 | prune_funcs['magnitude'] = magnitude_prune 326 | prune_funcs['SET'] = magnitude_and_negativity_prune 327 | prune_funcs['global_magnitude'] = global_magnitude_prune 328 | 329 | growth_funcs = {} 330 | growth_funcs['random'] = random_growth 331 | growth_funcs['momentum'] = momentum_growth 332 | growth_funcs['momentum_neuron'] = momentum_neuron_growth 333 | growth_funcs['global_momentum_growth'] = global_momentum_growth 334 | 335 | redistribution_funcs = {} 336 | redistribution_funcs['momentum'] = momentum_redistribution 337 | redistribution_funcs['nonzero'] = nonzero_redistribution 338 | redistribution_funcs['magnitude'] = magnitude_redistribution 339 | redistribution_funcs['none'] = no_redistribution 340 | -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.classifier = nn.Linear(512*block.expansion, num_classes, bias=False) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.classifier(out) 97 | out = F.log_softmax(out, dim=1) 98 | return out 99 | 100 | 101 | def ResNet18(c=1000): 102 | return ResNet(BasicBlock, [2,2,2,2],c) 103 | 104 | def ResNet34(c=10): 105 | return ResNet(BasicBlock, [3,4,6,3],c) 106 | 107 | def ResNet50(c=10): 108 | return 
ResNet(Bottleneck, [3,4,6,3],c) 109 | 110 | def ResNet101(c=10): 111 | return ResNet(Bottleneck, [3,4,23,3],c) 112 | 113 | def ResNet152(c=10): 114 | return ResNet(Bottleneck, [3,8,36,3],c) 115 | 116 | 117 | def test(): 118 | net = ResNet18() 119 | y = net(torch.randn(1,3,32,32)) 120 | print(y.size()) 121 | 122 | # test() 123 | -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/sparse_sgd.py: -------------------------------------------------------------------------------- 1 | from torch.optim.optimizer import Optimizer, required 2 | import torch 3 | import numpy as np 4 | class sparse_SGD(Optimizer): 5 | r"""Implements sparse stochastic gradient descent (optionally with momentum), according to the pytorch version 1.5.1. 6 | 7 | Nesterov momentum is based on the formula from 8 | `On the importance of initialization and momentum in deep learning`__. 9 | 10 | Args: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float): learning rate 14 | momentum (float, optional): momentum factor (default: 0) 15 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 16 | dampening (float, optional): dampening for momentum (default: 0) 17 | nesterov (bool, optional): enables Nesterov momentum (default: False) 18 | 19 | Example: 20 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 21 | >>> optimizer.zero_grad() 22 | >>> loss_fn(model(input), target).backward() 23 | >>> optimizer.step() 24 | 25 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 26 | 27 | .. note:: 28 | The implementation of SGD with Momentum/Nesterov subtly differs from 29 | Sutskever et. al. and implementations in some other frameworks. 30 | 31 | Considering the specific case of Momentum, the update can be written as 32 | 33 | .. math:: 34 | v = \rho * v + g \\ 35 | p = p - lr * v 36 | 37 | where p, g, v and :math:`\rho` denote the parameters, gradient, 38 | velocity, and momentum respectively. 39 | 40 | This is in contrast to Sutskever et. al. and 41 | other frameworks which employ an update of the form 42 | 43 | .. math:: 44 | v = \rho * v + lr * g \\ 45 | p = p - v 46 | 47 | The Nesterov version is analogously modified. 48 | """ 49 | 50 | def __init__(self, params, lr=required, momentum=0, dampening=0, 51 | weight_decay=0, nesterov=False): 52 | if lr is not required and lr < 0.0: 53 | raise ValueError("Invalid learning rate: {}".format(lr)) 54 | if momentum < 0.0: 55 | raise ValueError("Invalid momentum value: {}".format(momentum)) 56 | if weight_decay < 0.0: 57 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 58 | 59 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 60 | weight_decay=weight_decay, nesterov=nesterov) 61 | if nesterov and (momentum <= 0 or dampening != 0): 62 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 63 | super(sparse_SGD, self).__init__(params, defaults) 64 | 65 | def __setstate__(self, state): 66 | super(sparse_SGD, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('nesterov', False) 69 | 70 | @torch.no_grad() 71 | def step(self, closure=None, nonzero_masks=None, new_masks=None, gamma=None, epoch=None): 72 | """Performs a single optimization step. 73 | 74 | Arguments: 75 | closure (callable, optional): A closure that reevaluates the model 76 | and returns the loss. 
77 | """ 78 | loss = None 79 | if closure is not None: 80 | with torch.enable_grad(): 81 | loss = closure() 82 | 83 | if epoch <= 100: 84 | for group in self.param_groups: 85 | weight_decay = group['weight_decay'] 86 | momentum = group['momentum'] 87 | dampening = group['dampening'] 88 | nesterov = group['nesterov'] 89 | 90 | for p in group['params']: 91 | if p.grad is None: 92 | continue 93 | d_p = p.grad 94 | if weight_decay != 0: 95 | d_p = d_p.add(p, alpha=weight_decay) 96 | if momentum != 0: 97 | param_state = self.state[p] 98 | if 'momentum_buffer' not in param_state: 99 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 100 | else: 101 | buf = param_state['momentum_buffer'] 102 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 103 | if nesterov: 104 | d_p = d_p.add(buf, alpha=momentum) 105 | else: 106 | d_p = buf 107 | 108 | p.add_(d_p, alpha=-group['lr']) 109 | else: 110 | for group in self.param_groups: 111 | weight_decay = group['weight_decay'] 112 | momentum = group['momentum'] 113 | dampening = group['dampening'] 114 | nesterov = group['nesterov'] 115 | 116 | for i, p in enumerate(group['params']): 117 | if p.grad is None: 118 | continue 119 | 120 | sparse_layer_flag = False 121 | for key in nonzero_masks.keys(): 122 | if i == float(key.split('_')[-1]): 123 | nonzero_mask = nonzero_masks[key] 124 | new_mask = new_masks[key] 125 | sparse_layer_flag = True 126 | 127 | d_p = p.grad 128 | if weight_decay != 0: 129 | d_p = d_p.add(p, alpha=weight_decay) 130 | if momentum != 0: 131 | param_state = self.state[p] 132 | if 'momentum_buffer' not in param_state: 133 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 134 | else: 135 | buf = param_state['momentum_buffer'] 136 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 137 | if nesterov: 138 | d_p = d_p.add(buf, alpha=momentum) 139 | else: 140 | d_p = buf 141 | 142 | p.add_(d_p, alpha=-group['lr']) 143 | 144 | if sparse_layer_flag: 145 | p.add_(d_p * nonzero_mask, alpha=-group['lr']) 146 | p.add_(d_p * new_mask, alpha=-gamma) 147 | 148 | else: 149 | p.add_(d_p, alpha=-group['lr']) 150 | 151 | return loss -------------------------------------------------------------------------------- /SViTE/backup/sparselearning/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from torchvision import datasets, transforms 7 | 8 | class DatasetSplitter(torch.utils.data.Dataset): 9 | """This splitter makes sure that we always use the same training/validation split""" 10 | def __init__(self,parent_dataset,split_start=-1,split_end= -1): 11 | split_start = split_start if split_start != -1 else 0 12 | split_end = split_end if split_end != -1 else len(parent_dataset) 13 | assert split_start <= len(parent_dataset) - 1 and split_end <= len(parent_dataset) and split_start < split_end , "invalid dataset split" 14 | 15 | self.parent_dataset = parent_dataset 16 | self.split_start = split_start 17 | self.split_end = split_end 18 | 19 | def __len__(self): 20 | return self.split_end - self.split_start 21 | 22 | 23 | def __getitem__(self,index): 24 | assert index < len(self),"index out of bounds in split_datset" 25 | return self.parent_dataset[index + self.split_start] 26 | 27 | def get_cifar10_dataloaders(args, validation_split=0.0, max_threads=10): 28 | """Creates augmented train, validation, and test data loaders.""" 29 | 30 | normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), 31 | 
(0.2023, 0.1994, 0.2010)) 32 | 33 | train_transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), 36 | (4,4,4,4),mode='reflect').squeeze()), 37 | transforms.ToPILImage(), 38 | transforms.RandomCrop(32), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | normalize, 42 | ]) 43 | 44 | test_transform = transforms.Compose([ 45 | transforms.ToTensor(), 46 | normalize 47 | ]) 48 | 49 | full_dataset = datasets.CIFAR10('_dataset/cifar10', True, train_transform, download=True) 50 | test_dataset = datasets.CIFAR10('_dataset/cifar10', False, test_transform, download=False) 51 | 52 | 53 | # we need at least two threads 54 | max_threads = 2 if max_threads < 2 else max_threads 55 | if max_threads >= 6: 56 | val_threads = 2 57 | train_threads = max_threads - val_threads 58 | else: 59 | val_threads = 1 60 | train_threads = max_threads - 1 61 | 62 | 63 | valid_loader = None 64 | if validation_split > 0.0: 65 | split = int(np.floor((1.0-validation_split) * len(full_dataset))) 66 | train_dataset = DatasetSplitter(full_dataset,split_end=split) 67 | val_dataset = DatasetSplitter(full_dataset,split_start=split) 68 | train_loader = torch.utils.data.DataLoader( 69 | train_dataset, 70 | args.batch_size, 71 | num_workers=train_threads, 72 | pin_memory=True, shuffle=True) 73 | valid_loader = torch.utils.data.DataLoader( 74 | val_dataset, 75 | args.test_batch_size, 76 | num_workers=val_threads, 77 | pin_memory=True) 78 | else: 79 | train_loader = torch.utils.data.DataLoader( 80 | full_dataset, 81 | args.batch_size, 82 | num_workers=8, 83 | pin_memory=True, shuffle=True) 84 | 85 | print('Train loader length', len(train_loader)) 86 | 87 | test_loader = torch.utils.data.DataLoader( 88 | test_dataset, 89 | args.test_batch_size, 90 | shuffle=False, 91 | num_workers=1, 92 | pin_memory=True) 93 | 94 | return train_loader, valid_loader, test_loader 95 | 96 | def get_tinyimagenet_dataloaders(args, validation_split=0.0): 97 | traindir = os.path.join(args.datadir, 'train') 98 | valdir = os.path.join(args.datadir, 'val') 99 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 100 | std=[0.229, 0.224, 0.225]) 101 | 102 | train_dataset = datasets.ImageFolder( 103 | traindir, 104 | transforms.Compose([ 105 | transforms.RandomResizedCrop(224), 106 | transforms.RandomHorizontalFlip(), 107 | transforms.ToTensor(), 108 | normalize, 109 | ])) 110 | 111 | if args.distributed: 112 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 113 | else: 114 | train_sampler = None 115 | 116 | train_loader = torch.utils.data.DataLoader( 117 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 118 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 119 | 120 | val_loader = torch.utils.data.DataLoader( 121 | datasets.ImageFolder(valdir, transforms.Compose([ 122 | transforms.Resize(256), 123 | transforms.CenterCrop(224), 124 | transforms.ToTensor(), 125 | normalize, 126 | ])), 127 | batch_size=args.batch_size, shuffle=False, 128 | num_workers=args.workers, pin_memory=True) 129 | return train_loader, val_loader 130 | 131 | def get_mnist_dataloaders(args, validation_split=0.0): 132 | """Creates augmented train, validation, and test data loaders.""" 133 | normalize = transforms.Normalize((0.1307,), (0.3081,)) 134 | transform = transform=transforms.Compose([transforms.ToTensor(),normalize]) 135 | 136 | full_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform) 137 | 
test_dataset = datasets.MNIST('../data', train=False, transform=transform) 138 | 139 | dataset_size = len(full_dataset) 140 | indices = list(range(dataset_size)) 141 | split = int(np.floor(validation_split * dataset_size)) 142 | 143 | valid_loader = None 144 | if validation_split > 0.0: 145 | split = int(np.floor((1.0-validation_split) * len(full_dataset))) 146 | train_dataset = DatasetSplitter(full_dataset,split_end=split) 147 | val_dataset = DatasetSplitter(full_dataset,split_start=split) 148 | train_loader = torch.utils.data.DataLoader( 149 | train_dataset, 150 | args.batch_size, 151 | num_workers=8, 152 | pin_memory=True, shuffle=True) 153 | valid_loader = torch.utils.data.DataLoader( 154 | val_dataset, 155 | args.test_batch_size, 156 | num_workers=2, 157 | pin_memory=True) 158 | else: 159 | train_loader = torch.utils.data.DataLoader( 160 | full_dataset, 161 | args.batch_size, 162 | num_workers=8, 163 | pin_memory=True, shuffle=True) 164 | 165 | print('Train loader length', len(train_loader)) 166 | 167 | test_loader = torch.utils.data.DataLoader( 168 | test_dataset, 169 | args.test_batch_size, 170 | shuffle=False, 171 | num_workers=1, 172 | pin_memory=True) 173 | 174 | return train_loader, valid_loader, test_loader 175 | 176 | 177 | def plot_class_feature_histograms(args, model, device, test_loader, optimizer): 178 | if not os.path.exists('./results'): os.mkdir('./results') 179 | model.eval() 180 | agg = {} 181 | num_classes = 10 182 | feat_id = 0 183 | sparse = not args.dense 184 | model_name = 'alexnet' 185 | #model_name = 'vgg' 186 | #model_name = 'wrn' 187 | 188 | 189 | densities = None 190 | for batch_idx, (data, target) in enumerate(test_loader): 191 | if batch_idx % 100 == 0: print(batch_idx,'/', len(test_loader)) 192 | with torch.no_grad(): 193 | #if batch_idx == 10: break 194 | data, target = data.to(device), target.to(device) 195 | for cls in range(num_classes): 196 | #print('=='*50) 197 | #print('CLASS {0}'.format(cls)) 198 | model.t = target 199 | sub_data = data[target == cls] 200 | 201 | output = model(sub_data) 202 | 203 | feats = model.feats 204 | if densities is None: 205 | densities = [] 206 | densities += model.densities 207 | 208 | if len(agg) == 0: 209 | for feat_id, feat in enumerate(feats): 210 | agg[feat_id] = [] 211 | #print(feat.shape) 212 | for i in range(feat.shape[1]): 213 | agg[feat_id].append(np.zeros((num_classes,))) 214 | 215 | for feat_id, feat in enumerate(feats): 216 | map_contributions = torch.abs(feat).sum([0, 2, 3]) 217 | for map_id in range(map_contributions.shape[0]): 218 | #print(feat_id, map_id, cls) 219 | #print(len(agg), len(agg[feat_id]), len(agg[feat_id][map_id]), len(feats)) 220 | agg[feat_id][map_id][cls] += map_contributions[map_id].item() 221 | 222 | del model.feats[:] 223 | del model.densities[:] 224 | model.feats = [] 225 | model.densities = [] 226 | 227 | if sparse: 228 | np.save('./results/{0}_sparse_density_data'.format(model_name), densities) 229 | 230 | for feat_id, map_data in agg.items(): 231 | data = np.array(map_data) 232 | #print(feat_id, data) 233 | full_contribution = data.sum() 234 | #print(full_contribution, data) 235 | contribution_per_channel = ((1.0/full_contribution)*data.sum(1)) 236 | #print('pre', data.shape[0]) 237 | channels = data.shape[0] 238 | #data = data[contribution_per_channel > 0.001] 239 | 240 | channel_density = np.cumsum(np.sort(contribution_per_channel)) 241 | print(channel_density) 242 | idx = np.argsort(contribution_per_channel) 243 | 244 | threshold_idx = np.searchsorted(channel_density, 0.05) 245 
| print(data.shape, 'pre') 246 | data = data[idx[threshold_idx:]] 247 | print(data.shape, 'post') 248 | 249 | #perc = np.percentile(contribution_per_channel[contribution_per_channel > 0.0], 10) 250 | #print(contribution_per_channel, perc, feat_id) 251 | #data = data[contribution_per_channel > perc] 252 | #print(contribution_per_channel[contribution_per_channel < perc].sum()) 253 | #print('post', data.shape[0]) 254 | normed_data = np.max(data/np.sum(data,1).reshape(-1, 1), 1) 255 | #normed_data = (data/np.sum(data,1).reshape(-1, 1) > 0.2).sum(1) 256 | #counts, bins = np.histogram(normed_data, bins=4, range=(0, 4)) 257 | np.save('./results/{2}_{1}_feat_data_layer_{0}'.format(feat_id, 'sparse' if sparse else 'dense', model_name), normed_data) 258 | #plt.ylim(0, channels/2.0) 259 | ##plt.hist(normed_data, bins=range(0, 5)) 260 | #plt.hist(normed_data, bins=[(i+20)/float(200) for i in range(180)]) 261 | #plt.xlim(0.1, 0.5) 262 | #if sparse: 263 | # plt.title("Sparse: Conv2D layer {0}".format(feat_id)) 264 | # plt.savefig('./output/feat_histo/layer_{0}_sp.png'.format(feat_id)) 265 | #else: 266 | # plt.title("Dense: Conv2D layer {0}".format(feat_id)) 267 | # plt.savefig('./output/feat_histo/layer_{0}_d.png'.format(feat_id)) 268 | #plt.clf() 269 | -------------------------------------------------------------------------------- /SViTE/backup/test_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | 6 | all_masks = {} 7 | for i in range(8): 8 | all_masks[i] = torch.load('{}-init_mask.pt'.format(i), map_location='cpu') 9 | 10 | for key in all_masks[0].keys(): 11 | result = [] 12 | for i in range(8): 13 | result.append((all_masks[i][key]==all_masks[1][key]).float().mean().item()) 14 | print(key, result) 15 | 16 | 17 | all_masks = {} 18 | for i in range(8): 19 | all_masks[i] = torch.load('{}-init_mask_syn.pt'.format(i), map_location='cpu') 20 | 21 | for key in all_masks[0].keys(): 22 | result = [] 23 | for i in range(8): 24 | result.append((all_masks[i][key]==all_masks[1][key]).float().mean().item()) 25 | print(key, result) -------------------------------------------------------------------------------- /SViTE/backup/vmsh/0404/01_vm1_run_tiny_dst.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 1000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/backup/vmsh/0404/03_vm3_run_tiny_dst.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.4 \ 13 | --update_frequency 1000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- 
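Note: every launch script under backup/vmsh (and the later cmd/vm scripts) forwards its first positional argument to CUDA_VISIBLE_DEVICES before starting an 8-process torch.distributed.launch run, so the caller picks which GPUs are used. A minimal invocation sketch, assuming an 8-GPU node and that the script is run from the SViTE/ directory so that main.py and the hard-coded ../../imagenet data path resolve:

# hypothetical invocation; the GPU list below is an assumption for an 8-GPU machine
bash backup/vmsh/0404/01_vm1_run_tiny_dst.sh 0,1,2,3,4,5,6,7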
/SViTE/backup/vmsh/0404/09_vm9_run_tiny_dst.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.3 \ 13 | --update_frequency 1000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/backup/vmsh/0405/01_vm1_run_tiny_dst.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_040501 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 2000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/backup/vmsh/0405/03_vm3_run_tiny_dst.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_040502 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 5000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/backup/vmsh/0405/09_vm9_run_tiny_dst.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_040503 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 10000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/utsh/01_run_tiny_dst_structure.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=4 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 128 \ 8 | --data-path ../../../imagenet \ 9 | --output_dir ./tiny_structure \ 10 | --dist_url tcp://127.0.0.1:2454 \ 11 | --sparse_init fixed_ERK \ 12 | --update_frequency 1000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --atten_head 3 \ 17 | --pruning_type structure -------------------------------------------------------------------------------- /SViTE/cmd/vm/0409/vm1.sh: -------------------------------------------------------------------------------- 1 | 
CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0409_vm1 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0409/vm3.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 32 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0409_vm3 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 10000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0409/vm9.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 16 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0409_vm9 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 20000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0414/vm1.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0414_vm1 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 20000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0414/vm3.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0414_vm3 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 25000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0414/vm9.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir 
./tiny_dst_uns_0414_vm9 \ 10 | --dist_url tcp://127.0.0.1:2457 \ 11 | --sparse_init custom \ 12 | --density 0.5 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none \ 17 | --mask_path deit_tiny_mask_init/deit_tiny_patch16_224_sparse0.5_at_init.pt -------------------------------------------------------------------------------- /SViTE/cmd/vm/0416/vm1.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_small_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0417_vm1 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.3 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0416/vm2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0416_vm2_DIS \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 20000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0416/vm3.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_small_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0417_vm3 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0416/vm9.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_base_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 128 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0417_vm9 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 7000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0420/vm2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_base_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 128 \ 8 | --data-path /datadrive_a/TLC/imagenet \ 9 | --output_dir ./tiny_dst_uns_0420_vm2 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 
6000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0422/vm1.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0422_vm1 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.6 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0422/vm3.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_tiny_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./tiny_dst_uns_0422_vm1 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.7 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0424/vm9.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_base_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 128 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./base_dst_uns_0424_vm9 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.3 \ 13 | --update_frequency 6000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0426/vm1.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_small_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./small_dst_uns_0426_vm1 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.4 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0426/vm3.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 \ 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env main.py \ 5 | --model deit_base_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 128 \ 8 | --data-path ../../imagenet \ 9 | --output_dir ./base_dst_uns_0426_vm3 \ 10 | --dist_url tcp://127.0.0.1:23305 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.4 \ 13 | --update_frequency 7000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0428/deit_tiny_structure.sh: 
-------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_tiny_patch16_224 \ 6 | --atten_head 3 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.7 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path /datadrive_a/TLC/imagenet \ 14 | --output_dir ./tiny_dst_structure_0428 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.7 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0430/deit_base_structure.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_base_patch16_224 \ 6 | --atten_head 12 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.6 \ 9 | --update_frequency 7000 \ 10 | --epochs 600 \ 11 | --batch-size 128 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./base_dst_structure_0430 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.6 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0502/vm14.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_tiny_patch16_224 \ 6 | --atten_head 3 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.7 \ 9 | --update_frequency 20000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./tiny_dst_structure_0502 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.7 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0502/vm2.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.6 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_0502 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.6 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0502/vm3.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_base_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 128 \ 7 | --data-path $1 \ 8 | --output_dir ./base_dst_uns_0502 \ 9 | --dist_url tcp://127.0.0.1:23305 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.3 \ 12 | --update_frequency 7000 \ 13 | --growth gradient \ 14 | --death 
magnitude \ 15 | --redistribution none -------------------------------------------------------------------------------- /SViTE/cmd/vm/0506/structure_small.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.5 \ 8 | --other_density 0.5 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_0502 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.5 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0506/vm3.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_base_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 128 \ 7 | --data-path $1 \ 8 | --output_dir ./base_dst_uns_0502 \ 9 | --dist_url tcp://127.0.0.1:23305 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.3 \ 12 | --update_frequency 7000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --resume ./base_dst_uns_0502/checkpoint.pth \ 17 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0507/line67.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_line67 \ 9 | --dist_url tcp://127.0.0.1:23305 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.6 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0507/line69.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_base_patch16_224 \ 6 | --atten_head 12 \ 7 | --atten_density 0.5 \ 8 | --other_density 0.5 \ 9 | --update_frequency 7000 \ 10 | --epochs 600 \ 11 | --batch-size 128 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./base_dst_structure_line69 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.5 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0507/vm1.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_base_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 128 \ 7 | --data-path $1 \ 8 | --output_dir ./base_dst_uns_0507 \ 9 | --dist_url tcp://127.0.0.1:23305 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.6 \ 12 | --update_frequency 7000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | 
--redistribution none \ 16 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line71.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_line71 \ 9 | --dist_url tcp://127.0.0.1:23305 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.7 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line72.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.7 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_line72 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.7 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line73.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.4 \ 8 | --other_density 0.4 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_line73 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.4 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line74.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.4 \ 8 | --other_density 0.3 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_line74 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.3 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line75.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_base_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 128 \ 7 | --data-path $1 \ 8 | --output_dir ./base_dst_uns_line75 \ 9 | --dist_url tcp://127.0.0.1:23305 \ 10 | --sparse_init fixed_ERK \ 11 | --density 
0.7 \ 12 | --update_frequency 7000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line77.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_base_patch16_224 \ 6 | --atten_head 12 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.7 \ 9 | --update_frequency 7000 \ 10 | --epochs 600 \ 11 | --batch-size 128 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./base_dst_structure_line77 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.7 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line78.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_base_patch16_224 \ 6 | --atten_head 12 \ 7 | --atten_density 0.4 \ 8 | --other_density 0.4 \ 9 | --update_frequency 7000 \ 10 | --epochs 600 \ 11 | --batch-size 128 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./base_dst_structure_line78 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.4 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0508/line79.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_base_patch16_224 \ 6 | --atten_head 12 \ 7 | --atten_density 0.4 \ 8 | --other_density 0.3 \ 9 | --update_frequency 7000 \ 10 | --epochs 600 \ 11 | --batch-size 128 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./base_dst_structure_line79 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.3 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0509/gumbel.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_gumbel_0509 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection -------------------------------------------------------------------------------- /SViTE/cmd/vm/0516/line70.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | 
--atten_density 0.7 \ 8 | --other_density 0.6 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_token1 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.9 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 \ 22 | --token_selection \ 23 | --token_number 197 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/line67.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_gumbel_line67 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection \ 18 | --token_number 177 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/line67_new_95.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_gumbel_187 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection \ 18 | --token_number 187 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/line68.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_gumbel_line68 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection \ 18 | --token_number 138 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/line69.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_gumbel_line69 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection \ 18 | --token_number 99 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/line71.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env 
main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.6 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_token177 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.9 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 \ 22 | --token_selection \ 23 | --token_number 177 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/line72.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.6 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_token138 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.9 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 \ 22 | --token_selection \ 23 | --token_number 138 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/line73.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.6 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_token99 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.9 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 \ 22 | --token_selection \ 23 | --token_number 99 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0517/resume.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_base_patch16_224 \ 5 | --epochs 600 \ 6 | --batch-size 128 \ 7 | --data-path $1 \ 8 | --output_dir ./base_dst_uns_0502 \ 9 | --dist_url tcp://127.0.0.1:23305 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.3 \ 12 | --update_frequency 7000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --resume ./base_dst_uns_0502/checkpoint.pth \ 17 | --num_workers $2 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0519/layer_wise_token_uns_small.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224_data \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_layerwise_gumbel_density50_token177 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient 
\ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection \ 18 | --token_number 177 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0519/line72_95_s.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --pruning_type structure_new \ 5 | --model deit_small_patch16_224 \ 6 | --atten_head 6 \ 7 | --atten_density 0.7 \ 8 | --other_density 0.6 \ 9 | --update_frequency 15000 \ 10 | --epochs 600 \ 11 | --batch-size 64 \ 12 | --t_end 0.8 \ 13 | --data-path $1 \ 14 | --output_dir ./small_dst_structure_token187 \ 15 | --dist_url tcp://127.0.0.1:23305 \ 16 | --density 0.9 \ 17 | --sparse_init fixed_ERK \ 18 | --growth gradient \ 19 | --death magnitude \ 20 | --redistribution none \ 21 | --num_workers $2 \ 22 | --token_selection \ 23 | --token_number 187 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0519/line78.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224_data \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_layerwise_gumbel_density50_token187 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection \ 18 | --token_number 187 -------------------------------------------------------------------------------- /SViTE/cmd/vm/0519/line79.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch \ 2 | --nproc_per_node=8 \ 3 | --use_env main.py \ 4 | --model deit_small_patch16_224_data \ 5 | --epochs 600 \ 6 | --batch-size 64 \ 7 | --data-path $1 \ 8 | --output_dir ./small_dst_uns_layerwise_gumbel_density50_token138 \ 9 | --dist_url tcp://127.0.0.1:25503 \ 10 | --sparse_init fixed_ERK \ 11 | --density 0.5 \ 12 | --update_frequency 15000 \ 13 | --growth gradient \ 14 | --death magnitude \ 15 | --redistribution none \ 16 | --num_workers $2 \ 17 | --token_selection \ 18 | --token_number 138 -------------------------------------------------------------------------------- /SViTE/cmd/vm/inference.sh: -------------------------------------------------------------------------------- 1 | # export NCCL_P2P_DISABLE=1 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=4 \ 4 | --use_env inference.py \ 5 | --model deit_small_patch16_224 \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path $1 \ 9 | --output_dir ./inference \ 10 | --dist_url tcp://127.0.0.1:25503 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none \ 17 | --num_workers 16 \ 18 | --token_selection \ 19 | --token_number $2 \ 20 | --resume $3 \ 21 | --train_eval -------------------------------------------------------------------------------- /SViTE/cmd/vm/inference_data.sh: -------------------------------------------------------------------------------- 1 | # export NCCL_P2P_DISABLE=1 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=4 \ 4 | --use_env inference.py \ 5 | --model 
deit_small_patch16_224_data \ 6 | --epochs 600 \ 7 | --batch-size 64 \ 8 | --data-path $1 \ 9 | --output_dir ./inference \ 10 | --dist_url tcp://127.0.0.1:25503 \ 11 | --sparse_init fixed_ERK \ 12 | --density 0.5 \ 13 | --update_frequency 15000 \ 14 | --growth gradient \ 15 | --death magnitude \ 16 | --redistribution none \ 17 | --num_workers 16 \ 18 | --token_selection \ 19 | --token_number $2 \ 20 | --resume $3 \ 21 | --train_eval -------------------------------------------------------------------------------- /SViTE/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | 12 | 13 | class INatDataset(ImageFolder): 14 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 15 | category='name', loader=default_loader): 16 | self.transform = transform 17 | self.loader = loader 18 | self.target_transform = target_transform 19 | self.year = year 20 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 21 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 22 | with open(path_json) as json_file: 23 | data = json.load(json_file) 24 | 25 | with open(os.path.join(root, 'categories.json')) as json_file: 26 | data_catg = json.load(json_file) 27 | 28 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 29 | 30 | with open(path_json_for_targeter) as json_file: 31 | data_for_targeter = json.load(json_file) 32 | 33 | targeter = {} 34 | indexer = 0 35 | for elem in data_for_targeter['annotations']: 36 | king = [] 37 | king.append(data_catg[int(elem['category_id'])][category]) 38 | if king[0] not in targeter.keys(): 39 | targeter[king[0]] = indexer 40 | indexer += 1 41 | self.nb_classes = len(targeter) 42 | 43 | self.samples = [] 44 | for elem in data['images']: 45 | cut = elem['file_name'].split('/') 46 | target_current = int(cut[2]) 47 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 48 | 49 | categors = data_catg[target_current] 50 | target_current_true = targeter[categors[category]] 51 | self.samples.append((path_current, target_current_true)) 52 | 53 | # __getitem__ and __len__ inherited from ImageFolder 54 | 55 | 56 | def build_dataset(is_train, args, no_augmentation=False): 57 | transform = build_transform(is_train, args) 58 | if no_augmentation: 59 | print('training set without augmentation') 60 | transform = build_transform(False, args) 61 | 62 | 63 | if args.data_set == 'CIFAR': 64 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 65 | nb_classes = 100 66 | elif args.data_set == 'IMNET': 67 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 68 | dataset = datasets.ImageFolder(root, transform=transform) 69 | nb_classes = 1000 70 | elif args.data_set == 'INAT': 71 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 72 | category=args.inat_category, transform=transform) 73 | nb_classes = dataset.nb_classes 74 | elif args.data_set == 'INAT19': 75 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 76 | category=args.inat_category, transform=transform) 77 | nb_classes = dataset.nb_classes 78 | 79 | 
return dataset, nb_classes 80 | 81 | 82 | def build_transform(is_train, args): 83 | resize_im = args.input_size > 32 84 | if is_train: 85 | # this should always dispatch to transforms_imagenet_train 86 | transform = create_transform( 87 | input_size=args.input_size, 88 | is_training=True, 89 | color_jitter=args.color_jitter, 90 | auto_augment=args.aa, 91 | interpolation=args.train_interpolation, 92 | re_prob=args.reprob, 93 | re_mode=args.remode, 94 | re_count=args.recount, 95 | ) 96 | if not resize_im: 97 | # replace RandomResizedCropAndInterpolation with 98 | # RandomCrop 99 | transform.transforms[0] = transforms.RandomCrop( 100 | args.input_size, padding=4) 101 | return transform 102 | 103 | t = [] 104 | if resize_im: 105 | size = int((256 / 224) * args.input_size) 106 | t.append( 107 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 108 | ) 109 | t.append(transforms.CenterCrop(args.input_size)) 110 | 111 | t.append(transforms.ToTensor()) 112 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 113 | return transforms.Compose(t) 114 | -------------------------------------------------------------------------------- /SViTE/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | import time 10 | import torch 11 | 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | from losses import DistillationLoss 16 | import utils 17 | import pdb 18 | import warnings 19 | warnings.filterwarnings('ignore') 20 | 21 | def get_tau(start_tau, end_tau, ite, total): 22 | tau = start_tau + (end_tau - start_tau) * ite / total 23 | return tau 24 | 25 | ite_step = 0 26 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 27 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 28 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 29 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 30 | set_training_mode=True, mask=None, args=None): 31 | model.train(set_training_mode) 32 | metric_logger = utils.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | header = 'Epoch: [{}]'.format(epoch) 35 | print_freq = 10 36 | # pdb.set_trace() 37 | total_iteration = len(data_loader) * (args.epochs) 38 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 39 | samples = samples.to(device, non_blocking=True) 40 | targets = targets.to(device, non_blocking=True) 41 | 42 | global ite_step 43 | optimizer.zero_grad() 44 | if mixup_fn is not None: 45 | samples, targets = mixup_fn(samples, targets) 46 | 47 | if args.token_selection: 48 | tau = get_tau(10, 0.1, ite_step, total_iteration) 49 | else: 50 | tau = -1 51 | 52 | with torch.cuda.amp.autocast(): 53 | if args.pruning_type == 'structure': 54 | outputs, atten_pruning_indicator = model(samples, tau=tau, number=args.token_number) 55 | else: 56 | outputs = model(samples, tau=tau, number=args.token_number) 57 | atten_pruning_indicator = None 58 | 59 | loss = criterion(samples, outputs, targets) 60 | 61 | loss_value = loss.item() 62 | if not math.isfinite(loss_value): 63 | print("Loss is {}, stopping training".format(loss_value)) 64 | sys.exit(1) 65 | 66 | # this attribute is added 
by timm on one optimizer (adahessian) 67 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 68 | loss_scaler(loss, optimizer, clip_grad=max_norm, 69 | parameters=model.parameters(), 70 | create_graph=is_second_order) 71 | if mask is not None: 72 | mask.step(pruning_type=args.pruning_type) 73 | 74 | torch.cuda.synchronize() 75 | if model_ema is not None: 76 | model_ema.update(model) 77 | metric_logger.update(loss=loss_value) 78 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 79 | # update sparse topology 80 | 81 | ite_step = mask.steps 82 | if ite_step % args.update_frequency == 0 and ite_step < args.t_end * total_iteration: 83 | mask.at_end_of_epoch(pruning_type=args.pruning_type, 84 | indicator_list=atten_pruning_indicator) 85 | # gather the stats from all processes 86 | metric_logger.synchronize_between_processes() 87 | print("Averaged stats:", metric_logger) 88 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 89 | 90 | 91 | def train_one_epoch_training_time(model: torch.nn.Module, criterion: DistillationLoss, 92 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 93 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 94 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 95 | set_training_mode=True, mask=None, args=None): 96 | model.train(set_training_mode) 97 | metric_logger = utils.MetricLogger(delimiter=" ") 98 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 99 | header = 'Epoch: [{}]'.format(epoch) 100 | print_freq = 10 101 | # pdb.set_trace() 102 | total_time = 0 103 | total_iteration = len(data_loader) * (args.epochs) 104 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 105 | samples = samples.to(device, non_blocking=True) 106 | targets = targets.to(device, non_blocking=True) 107 | 108 | optimizer.zero_grad() 109 | if mixup_fn is not None: 110 | samples, targets = mixup_fn(samples, targets) 111 | 112 | with torch.cuda.amp.autocast(): 113 | 114 | start = time.time() 115 | if args.pruning_type == 'structure': 116 | outputs, atten_pruning_indicator = model(samples) 117 | 118 | elif args.token_selection: 119 | outputs = model(samples, tau=10, number=args.token_number) 120 | atten_pruning_indicator = None 121 | 122 | else: 123 | outputs = model(samples) 124 | atten_pruning_indicator = None 125 | 126 | loss = criterion(samples, outputs, targets) 127 | 128 | loss_value = loss.item() 129 | if not math.isfinite(loss_value): 130 | print("Loss is {}, stopping training".format(loss_value)) 131 | sys.exit(1) 132 | 133 | # this attribute is added by timm on one optimizer (adahessian) 134 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 135 | loss_scaler(loss, optimizer, clip_grad=max_norm, 136 | parameters=model.parameters(), 137 | create_graph=is_second_order) 138 | 139 | end = time.time() 140 | total_time += end-start 141 | global ite_step 142 | ite_step += 1 143 | if ite_step % 100 == 0: 144 | print(total_time) 145 | total_time = 0 146 | 147 | 148 | # if mask is not None: 149 | # mask.step(pruning_type=args.pruning_type) 150 | 151 | torch.cuda.synchronize() 152 | if model_ema is not None: 153 | model_ema.update(model) 154 | metric_logger.update(loss=loss_value) 155 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 156 | # update sparse topology 157 | 158 | 159 | # if ite_step % args.update_frequency == 0 and ite_step < args.t_end * total_iteration: 160 
| # mask.at_end_of_epoch(pruning_type=args.pruning_type, 161 | # indicator_list=atten_pruning_indicator) 162 | # gather the stats from all processes 163 | metric_logger.synchronize_between_processes() 164 | print("Averaged stats:", metric_logger) 165 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 166 | 167 | 168 | @torch.no_grad() 169 | def evaluate(data_loader, model, device, args=None): 170 | criterion = torch.nn.CrossEntropyLoss() 171 | 172 | metric_logger = utils.MetricLogger(delimiter=" ") 173 | header = 'Test:' 174 | 175 | # switch to evaluation mode 176 | model.eval() 177 | 178 | if args.token_selection: 179 | tau = 1 180 | else: 181 | tau = -1 182 | 183 | for images, target in metric_logger.log_every(data_loader, 10, header): 184 | images = images.to(device, non_blocking=True) 185 | target = target.to(device, non_blocking=True) 186 | 187 | # compute output 188 | with torch.cuda.amp.autocast(): 189 | if args.pruning_type == 'structure': 190 | output, atten_pruning_indicator = model(images, tau=tau, number=args.token_number) 191 | else: 192 | output = model(images, tau=tau, number=args.token_number) 193 | atten_pruning_indicator = None 194 | 195 | # output = model(images) 196 | loss = criterion(output, target) 197 | 198 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 199 | 200 | batch_size = images.shape[0] 201 | metric_logger.update(loss=loss.item()) 202 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 203 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 204 | # gather the stats from all processes 205 | metric_logger.synchronize_between_processes() 206 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 207 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 208 | 209 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 210 | -------------------------------------------------------------------------------- /SViTE/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | from models import * 4 | 5 | dependencies = ["torch", "torchvision", "timm"] 6 | -------------------------------------------------------------------------------- /SViTE/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are feed to the teacher model 29 | outputs: the outputs of the model to be trained. 
It is expected to be 30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 31 | in the first position and the distillation predictions as the second output 32 | labels: the labels for the base criterion 33 | """ 34 | outputs_kd = None 35 | if not isinstance(outputs, torch.Tensor): 36 | # assume that the model outputs a tuple of [outputs, outputs_kd] 37 | outputs, outputs_kd = outputs 38 | base_loss = self.base_criterion(outputs, labels) 39 | if self.distillation_type == 'none': 40 | return base_loss 41 | 42 | if outputs_kd is None: 43 | raise ValueError("When knowledge distillation is enabled, the model is " 44 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 45 | "class_token and the dist_token") 46 | # don't backprop throught the teacher 47 | with torch.no_grad(): 48 | teacher_outputs = self.teacher_model(inputs) 49 | 50 | if self.distillation_type == 'soft': 51 | T = self.tau 52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 53 | # with slight modifications 54 | distillation_loss = F.kl_div( 55 | F.log_softmax(outputs_kd / T, dim=1), 56 | F.log_softmax(teacher_outputs / T, dim=1), 57 | reduction='sum', 58 | log_target=True 59 | ) * (T * T) / outputs_kd.numel() 60 | elif self.distillation_type == 'hard': 61 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 62 | 63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 64 | return loss 65 | -------------------------------------------------------------------------------- /SViTE/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | torchvision==0.8.1 3 | timm==0.3.2 4 | -------------------------------------------------------------------------------- /SViTE/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | A script to run multinode training with submitit. 5 | """ 6 | import argparse 7 | import os 8 | import uuid 9 | from pathlib import Path 10 | 11 | import main as classification 12 | import submitit 13 | 14 | 15 | def parse_args(): 16 | classification_parser = classification.get_args_parser() 17 | parser = argparse.ArgumentParser("Submitit for DeiT", parents=[classification_parser]) 18 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 19 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 20 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") 21 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 22 | 23 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 24 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 25 | parser.add_argument('--comment', default="", type=str, 26 | help='Comment to pass to scheduler, e.g. priority message') 27 | return parser.parse_args() 28 | 29 | 30 | def get_shared_folder() -> Path: 31 | user = os.getenv("USER") 32 | if Path("/checkpoint/").is_dir(): 33 | p = Path(f"/checkpoint/{user}/experiments") 34 | p.mkdir(exist_ok=True) 35 | return p 36 | raise RuntimeError("No shared folder available") 37 | 38 | 39 | def get_init_file(): 40 | # Init file must not exist, but it's parent dir must exist. 
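    # (clarifying note) the next two lines first make sure the shared experiments folder exists,
    # then derive a fresh, uniquely named init-file path inside it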
41 | os.makedirs(str(get_shared_folder()), exist_ok=True) 42 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 43 | if init_file.exists(): 44 | os.remove(str(init_file)) 45 | return init_file 46 | 47 | 48 | class Trainer(object): 49 | def __init__(self, args): 50 | self.args = args 51 | 52 | def __call__(self): 53 | import main as classification 54 | 55 | self._setup_gpu_args() 56 | classification.main(self.args) 57 | 58 | def checkpoint(self): 59 | import os 60 | import submitit 61 | 62 | self.args.dist_url = get_init_file().as_uri() 63 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 64 | if os.path.exists(checkpoint_file): 65 | self.args.resume = checkpoint_file 66 | print("Requeuing ", self.args) 67 | empty_trainer = type(self)(self.args) 68 | return submitit.helpers.DelayedSubmission(empty_trainer) 69 | 70 | def _setup_gpu_args(self): 71 | import submitit 72 | from pathlib import Path 73 | 74 | job_env = submitit.JobEnvironment() 75 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 76 | self.args.gpu = job_env.local_rank 77 | self.args.rank = job_env.global_rank 78 | self.args.world_size = job_env.num_tasks 79 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 80 | 81 | 82 | def main(): 83 | args = parse_args() 84 | if args.job_dir == "": 85 | args.job_dir = get_shared_folder() / "%j" 86 | 87 | # Note that the folder will depend on the job_id, to easily track experiments 88 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 89 | 90 | num_gpus_per_node = args.ngpus 91 | nodes = args.nodes 92 | timeout_min = args.timeout 93 | 94 | partition = args.partition 95 | kwargs = {} 96 | if args.use_volta32: 97 | kwargs['slurm_constraint'] = 'volta32gb' 98 | if args.comment: 99 | kwargs['slurm_comment'] = args.comment 100 | 101 | executor.update_parameters( 102 | mem_gb=40 * num_gpus_per_node, 103 | gpus_per_node=num_gpus_per_node, 104 | tasks_per_node=num_gpus_per_node, # one task per GPU 105 | cpus_per_task=10, 106 | nodes=nodes, 107 | timeout_min=timeout_min, # max is 60 * 72 108 | # Below are cluster dependent parameters 109 | slurm_partition=partition, 110 | slurm_signal_delay_s=120, 111 | **kwargs 112 | ) 113 | 114 | executor.update_parameters(name="deit") 115 | 116 | args.dist_url = get_init_file().as_uri() 117 | args.output_dir = args.job_dir 118 | 119 | trainer = Trainer(args) 120 | job = executor.submit(trainer) 121 | 122 | print("Submitted job_id:", job.job_id) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /SViTE/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # add extra samples to make it evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | -------------------------------------------------------------------------------- /SViTE/tox.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = F401,E402,F403,W503,W504 4 | -------------------------------------------------------------------------------- /SViTE/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 
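        Only the running count and total are all-reduced across processes; the window of
        recent values stays local to each process.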
39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, total_time / 
len(iterable))) 160 | 161 | 162 | def _load_checkpoint_for_ema(model_ema, checkpoint): 163 | """ 164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 165 | """ 166 | mem_file = io.BytesIO() 167 | torch.save(checkpoint, mem_file) 168 | mem_file.seek(0) 169 | model_ema._load_checkpoint(mem_file) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | import builtins as __builtin__ 177 | builtin_print = __builtin__.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | if is_master or force: 182 | builtin_print(*args, **kwargs) 183 | 184 | __builtin__.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 218 | args.rank = int(os.environ["RANK"]) 219 | args.world_size = int(os.environ['WORLD_SIZE']) 220 | args.gpu = int(os.environ['LOCAL_RANK']) 221 | elif 'SLURM_PROCID' in os.environ: 222 | args.rank = int(os.environ['SLURM_PROCID']) 223 | args.gpu = args.rank % torch.cuda.device_count() 224 | else: 225 | print('Not using distributed mode') 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}'.format( 234 | args.rank, args.dist_url), flush=True) 235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 236 | world_size=args.world_size, rank=args.rank) 237 | torch.distributed.barrier() 238 | setup_for_distributed(args.rank == 0) 239 | -------------------------------------------------------------------------------- /SViTE/vision2.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | # from .helpers import load_pretrained 9 | # from .layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | # from .resnet import resnet26d, resnet50d 12 | # from .registry import register_model 13 | import pdb 14 | import numpy as np 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | 27 | def get_sparsity(module): 28 | 29 | total = module.weight.numel() 30 | zeros = (module.weight==0).sum() 31 | return zeros / total 32 | 33 | 34 | class Mlp(nn.Module): 35 | def __init__(self, in_features, 
hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 36 | super().__init__() 37 | out_features = out_features or in_features 38 | hidden_features = hidden_features or in_features 39 | self.fc1 = nn.Linear(in_features, hidden_features) 40 | self.act = act_layer() 41 | self.fc2 = nn.Linear(hidden_features, out_features) 42 | self.drop = nn.Dropout(drop) 43 | 44 | def forward(self, x): 45 | 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | 53 | return x 54 | 55 | 56 | class Attention(nn.Module): 57 | def __init__(self, dim, sparse_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 58 | super().__init__() 59 | self.num_heads = num_heads 60 | head_dim = sparse_dim // num_heads 61 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 62 | self.scale = qk_scale or head_dim ** -0.5 63 | self.qkv = nn.Linear(dim, sparse_dim * 3, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | self.zero_dim = dim - sparse_dim 68 | self.sparse_dim = sparse_dim 69 | 70 | def forward(self, x): 71 | 72 | zeros = torch.zeros_like(x)[:,:,:self.zero_dim] 73 | 74 | B, N, C = x.shape 75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.sparse_dim // self.num_heads).permute(2, 0, 3, 1, 4) 76 | 77 | q, k, v = qkv[0], qkv[1], qkv[2] # torch.Size([256, 3, 197, 64]) 78 | attn = (q @ k.transpose(-2, -1)) * self.scale # torch.Size([256, 3, 64, 197]) 79 | attn = attn.softmax(dim=-1) # torch.Size([256, 3, 197, 197]) 80 | attn = self.attn_drop(attn) 81 | 82 | x = (attn @ v).transpose(1, 2) 83 | 84 | x = x.reshape(B, N, self.sparse_dim) 85 | x = torch.cat([x, zeros], dim=2) 86 | x = self.proj(x) 87 | 88 | x = self.proj_drop(x) 89 | return x 90 | 91 | 92 | class Block(nn.Module): 93 | 94 | def __init__(self, dim, num_heads, sparse_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 95 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 96 | super().__init__() 97 | self.norm1 = norm_layer(dim) 98 | self.attn = Attention( 99 | dim, sparse_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 100 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 101 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 102 | self.norm2 = norm_layer(dim) 103 | mlp_hidden_dim = int(dim * mlp_ratio) 104 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 105 | 106 | def forward(self, x): 107 | x = x + self.drop_path(self.attn(self.norm1(x))) 108 | x = x + self.drop_path(self.mlp(self.norm2(x))) 109 | return x 110 | 111 | 112 | class PatchEmbed(nn.Module): 113 | """ Image to Patch Embedding 114 | """ 115 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 116 | super().__init__() 117 | img_size = to_2tuple(img_size) 118 | patch_size = to_2tuple(patch_size) 119 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 120 | self.img_size = img_size 121 | self.patch_size = patch_size 122 | self.num_patches = num_patches 123 | 124 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 125 | 126 | def forward(self, x): 127 | B, C, H, W = x.shape 128 | # FIXME look at relaxing size constraints 129 | assert H == self.img_size[0] and W == self.img_size[1], \ 130 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 131 | 132 | x = self.proj(x).flatten(2).transpose(1, 2) 133 | return x 134 | 135 | 136 | class HybridEmbed(nn.Module): 137 | """ CNN Feature Map Embedding 138 | Extract feature map from CNN, flatten, project to embedding dim. 139 | """ 140 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 141 | super().__init__() 142 | assert isinstance(backbone, nn.Module) 143 | img_size = to_2tuple(img_size) 144 | self.img_size = img_size 145 | self.backbone = backbone 146 | if feature_size is None: 147 | with torch.no_grad(): 148 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 149 | # map for all networks, the feature metadata has reliable channel and stride info, but using 150 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
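                # probe the backbone with a dummy forward pass in eval mode to read off the
                # spatial size and channel count of its final feature map, then restore its original mode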
151 | training = backbone.training 152 | if training: 153 | backbone.eval() 154 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 155 | feature_size = o.shape[-2:] 156 | feature_dim = o.shape[1] 157 | backbone.train(training) 158 | else: 159 | feature_size = to_2tuple(feature_size) 160 | feature_dim = self.backbone.feature_info.channels()[-1] 161 | self.num_patches = feature_size[0] * feature_size[1] 162 | self.proj = nn.Linear(feature_dim, embed_dim) 163 | 164 | def forward(self, x): 165 | x = self.backbone(x)[-1] 166 | x = x.flatten(2).transpose(1, 2) 167 | x = self.proj(x) 168 | return x 169 | 170 | 171 | class VisionTransformer(nn.Module): 172 | """ Vision Transformer with support for patch or hybrid CNN input stage 173 | """ 174 | def __init__(self, sparse_dim, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 175 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 176 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pruning_type='unstructure', sparse=1.0): 177 | super().__init__() 178 | self.pruning_type = pruning_type 179 | self.num_classes = num_classes 180 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 181 | 182 | if hybrid_backbone is not None: 183 | self.patch_embed = HybridEmbed( 184 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 185 | else: 186 | self.patch_embed = PatchEmbed( 187 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 188 | num_patches = self.patch_embed.num_patches 189 | 190 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 191 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 192 | self.pos_drop = nn.Dropout(p=drop_rate) 193 | 194 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 195 | self.blocks = nn.ModuleList([ 196 | Block( 197 | dim=embed_dim, num_heads=num_heads, sparse_dim=sparse_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 198 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 199 | for i in range(depth)]) 200 | self.norm = norm_layer(embed_dim) 201 | 202 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 203 | #self.repr = nn.Linear(embed_dim, representation_size) 204 | #self.repr_act = nn.Tanh() 205 | 206 | # Classifier head 207 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 208 | 209 | trunc_normal_(self.pos_embed, std=.02) 210 | trunc_normal_(self.cls_token, std=.02) 211 | self.apply(self._init_weights) 212 | 213 | def _init_weights(self, m): 214 | if isinstance(m, nn.Linear): 215 | trunc_normal_(m.weight, std=.02) 216 | if isinstance(m, nn.Linear) and m.bias is not None: 217 | nn.init.constant_(m.bias, 0) 218 | elif isinstance(m, nn.LayerNorm): 219 | nn.init.constant_(m.bias, 0) 220 | nn.init.constant_(m.weight, 1.0) 221 | 222 | @torch.jit.ignore 223 | def no_weight_decay(self): 224 | return {'pos_embed', 'cls_token'} 225 | 226 | def get_classifier(self): 227 | return self.head 228 | 229 | def reset_classifier(self, num_classes, global_pool=''): 230 | self.num_classes = num_classes 231 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 232 | 233 | def forward_features(self, x): 234 | 235 | B = x.shape[0] 236 | x = self.patch_embed(x) 237 | 238 | cls_tokens = 
self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 239 | x = torch.cat((cls_tokens, x), dim=1) 240 | x = x + self.pos_embed 241 | x = self.pos_drop(x) 242 | 243 | for blk in self.blocks: 244 | x = blk(x) 245 | 246 | x = self.norm(x) 247 | return x[:, 0] 248 | 249 | def forward(self, x): 250 | x = self.forward_features(x) 251 | x = self.head(x) 252 | 253 | return x 254 | -------------------------------------------------------------------------------- /SViTE/vision3.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | # from .helpers import load_pretrained 9 | # from .layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | # from .resnet import resnet26d, resnet50d 12 | # from .registry import register_model 13 | import pdb 14 | import numpy as np 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | 27 | def get_sparsity(module): 28 | 29 | total = module.weight.numel() 30 | zeros = (module.weight==0).sum() 31 | return zeros / total 32 | 33 | 34 | class Mlp(nn.Module): 35 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 36 | super().__init__() 37 | out_features = out_features or in_features 38 | hidden_features = hidden_features or in_features 39 | self.fc1 = nn.Linear(in_features, hidden_features) 40 | self.act = act_layer() 41 | self.fc2 = nn.Linear(hidden_features, out_features) 42 | self.drop = nn.Dropout(drop) 43 | 44 | def forward(self, x): 45 | 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | 53 | return x 54 | 55 | 56 | class Attention(nn.Module): 57 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 58 | super().__init__() 59 | self.num_heads = num_heads 60 | head_dim = dim // num_heads 61 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 62 | self.scale = qk_scale or head_dim ** -0.5 63 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | def forward(self, x): 69 | 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | 74 | q, k, v = qkv[0], qkv[1], qkv[2] # torch.Size([256, 3, 197, 64]) 75 | attn = (q @ k.transpose(-2, -1)) * self.scale # torch.Size([256, 3, 64, 197]) 76 | 77 | attn = attn.softmax(dim=-1) # torch.Size([256, 3, 197, 197]) 78 | attn = self.attn_drop(attn) 79 | 80 | x = (attn @ v).transpose(1, 2) 81 | 82 | x = x.reshape(B, N, C) 83 | x = self.proj(x) 84 | x = self.proj_drop(x) 85 | 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 93 | 
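        # pre-norm transformer block: LayerNorm -> multi-head self-attention -> residual add,
        # followed by LayerNorm -> MLP -> residual add (see forward below)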
super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | 104 | def forward(self, x): 105 | x = x + self.drop_path(self.attn(self.norm1(x))) 106 | x = x + self.drop_path(self.mlp(self.norm2(x))) 107 | return x 108 | 109 | 110 | class PatchEmbed(nn.Module): 111 | """ Image to Patch Embedding 112 | """ 113 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 114 | super().__init__() 115 | img_size = to_2tuple(img_size) 116 | patch_size = to_2tuple(patch_size) 117 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 118 | self.img_size = img_size 119 | self.patch_size = patch_size 120 | self.num_patches = num_patches 121 | 122 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 123 | 124 | def forward(self, x): 125 | B, C, H, W = x.shape 126 | # FIXME look at relaxing size constraints 127 | assert H == self.img_size[0] and W == self.img_size[1], \ 128 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 129 | 130 | x = self.proj(x).flatten(2).transpose(1, 2) 131 | return x 132 | 133 | 134 | class HybridEmbed(nn.Module): 135 | """ CNN Feature Map Embedding 136 | Extract feature map from CNN, flatten, project to embedding dim. 137 | """ 138 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 139 | super().__init__() 140 | assert isinstance(backbone, nn.Module) 141 | img_size = to_2tuple(img_size) 142 | self.img_size = img_size 143 | self.backbone = backbone 144 | if feature_size is None: 145 | with torch.no_grad(): 146 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 147 | # map for all networks, the feature metadata has reliable channel and stride info, but using 148 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
149 | training = backbone.training 150 | if training: 151 | backbone.eval() 152 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 153 | feature_size = o.shape[-2:] 154 | feature_dim = o.shape[1] 155 | backbone.train(training) 156 | else: 157 | feature_size = to_2tuple(feature_size) 158 | feature_dim = self.backbone.feature_info.channels()[-1] 159 | self.num_patches = feature_size[0] * feature_size[1] 160 | self.proj = nn.Linear(feature_dim, embed_dim) 161 | 162 | def forward(self, x): 163 | x = self.backbone(x)[-1] 164 | x = x.flatten(2).transpose(1, 2) 165 | x = self.proj(x) 166 | return x 167 | 168 | 169 | class VisionTransformer(nn.Module): 170 | """ Vision Transformer with support for patch or hybrid CNN input stage 171 | """ 172 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 173 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 174 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pruning_type='unstructure'): 175 | super().__init__() 176 | self.pruning_type = pruning_type 177 | self.num_classes = num_classes 178 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 179 | 180 | if hybrid_backbone is not None: 181 | self.patch_embed = HybridEmbed( 182 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 183 | else: 184 | self.patch_embed = PatchEmbed( 185 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 186 | num_patches = self.patch_embed.num_patches 187 | 188 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 189 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 190 | self.pos_drop = nn.Dropout(p=drop_rate) 191 | 192 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 193 | self.blocks = nn.ModuleList([ 194 | Block( 195 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 196 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 197 | for i in range(depth)]) 198 | self.norm = norm_layer(embed_dim) 199 | 200 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 201 | #self.repr = nn.Linear(embed_dim, representation_size) 202 | #self.repr_act = nn.Tanh() 203 | 204 | # Classifier head 205 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 206 | 207 | trunc_normal_(self.pos_embed, std=.02) 208 | trunc_normal_(self.cls_token, std=.02) 209 | self.apply(self._init_weights) 210 | 211 | def _init_weights(self, m): 212 | if isinstance(m, nn.Linear): 213 | trunc_normal_(m.weight, std=.02) 214 | if isinstance(m, nn.Linear) and m.bias is not None: 215 | nn.init.constant_(m.bias, 0) 216 | elif isinstance(m, nn.LayerNorm): 217 | nn.init.constant_(m.bias, 0) 218 | nn.init.constant_(m.weight, 1.0) 219 | 220 | @torch.jit.ignore 221 | def no_weight_decay(self): 222 | return {'pos_embed', 'cls_token'} 223 | 224 | def get_classifier(self): 225 | return self.head 226 | 227 | def reset_classifier(self, num_classes, global_pool=''): 228 | self.num_classes = num_classes 229 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 230 | 231 | def forward_features(self, x): 232 | 233 | B = x.shape[0] 234 | x = self.patch_embed(x) 235 | 236 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole 
cls_tokens impl from Phil Wang, thanks 237 | x = torch.cat((cls_tokens, x), dim=1) 238 | x = x + self.pos_embed 239 | x = self.pos_drop(x) 240 | 241 | for blk in self.blocks: 242 | x = blk(x) 243 | 244 | x = self.norm(x) 245 | return x[:, 0] 246 | 247 | def forward(self, x): 248 | 249 | x = self.forward_features(x) 250 | x = self.head(x) 251 | 252 | return x 253 | -------------------------------------------------------------------------------- /SViTE/vision_gumbel.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | import torch.nn.functional as F 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | # from .helpers import load_pretrained 9 | # from .layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | # from .resnet import resnet26d, resnet50d 12 | # from .registry import register_model 13 | import pdb 14 | import numpy as np 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | 27 | def get_sparsity(module): 28 | 29 | total = module.weight.numel() 30 | zeros = (module.weight==0).sum() 31 | return zeros / total 32 | 33 | 34 | def scatter(logits, index, k): 35 | bs = logits.shape[0] 36 | #print('bs = {}'.format(bs)) 37 | 38 | x_index = torch.arange(bs).reshape(-1, 1).expand(bs,k) 39 | x_index = x_index.reshape(-1).tolist() 40 | y_index = index.reshape(-1).tolist() 41 | 42 | output = torch.zeros_like(logits).cuda() 43 | output[x_index, y_index] = 1.0 44 | #print(output.sum(dim=1)) 45 | 46 | return output 47 | 48 | def gumbel_softmax(logits, k, tau=1, hard=False, eps=1e-10, dim=-1): 49 | # type: (torch.Tensor, float, bool, float, int) -> torch.Tensor 50 | 51 | def _gen_gumbels(): 52 | gumbels = -torch.empty_like(logits).cuda().exponential_().log() 53 | if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum(): 54 | # to avoid zero in exp output 55 | gumbels = _gen_gumbels() 56 | return gumbels 57 | 58 | gumbels = _gen_gumbels() # ~Gumbel(0,1) 59 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) 60 | y_soft = gumbels.softmax(dim) 61 | 62 | if hard: 63 | # Straight through. 64 | index = y_soft.topk(k, dim=dim)[1] 65 | y_hard = scatter(logits, index, k) 66 | ret = y_hard - y_soft.detach() + y_soft 67 | else: 68 | # Reparametrization trick. 
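        # soft path: return the temperature-scaled probabilities directly, so gradients
        # flow through the relaxed softmax rather than a hard top-k mask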
69 | ret = y_soft 70 | 71 | if torch.isnan(ret).sum(): 72 | import ipdb 73 | ipdb.set_trace() 74 | raise OverflowError(f'gumbel softmax output: {ret}') 75 | return ret 76 | 77 | class Mlp(nn.Module): 78 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 79 | super().__init__() 80 | out_features = out_features or in_features 81 | hidden_features = hidden_features or in_features 82 | self.fc1 = nn.Linear(in_features, hidden_features) 83 | self.act = act_layer() 84 | self.fc2 = nn.Linear(hidden_features, out_features) 85 | self.drop = nn.Dropout(drop) 86 | 87 | def forward(self, x): 88 | 89 | x = self.fc1(x) 90 | x = self.act(x) 91 | x = self.drop(x) 92 | 93 | x = self.fc2(x) 94 | x = self.drop(x) 95 | 96 | return x 97 | 98 | 99 | class Attention(nn.Module): 100 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 101 | super().__init__() 102 | self.num_heads = num_heads 103 | head_dim = dim // num_heads 104 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 105 | self.scale = qk_scale or head_dim ** -0.5 106 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 107 | self.attn_drop = nn.Dropout(attn_drop) 108 | self.proj = nn.Linear(dim, dim) 109 | self.proj_drop = nn.Dropout(proj_drop) 110 | 111 | def forward(self, x): 112 | 113 | 114 | B, N, C = x.shape 115 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 116 | 117 | q, k, v = qkv[0], qkv[1], qkv[2] # torch.Size([256, 3, 197, 64]) 118 | attn = (q @ k.transpose(-2, -1)) * self.scale # torch.Size([256, 3, 64, 197]) 119 | 120 | attn = attn.softmax(dim=-1) # torch.Size([256, 3, 197, 197]) 121 | attn = self.attn_drop(attn) 122 | 123 | x = (attn @ v).transpose(1, 2) 124 | 125 | x = x.reshape(B, N, C) 126 | x = self.proj(x) 127 | x = self.proj_drop(x) 128 | 129 | return x 130 | 131 | 132 | class Block(nn.Module): 133 | 134 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 135 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 136 | super().__init__() 137 | self.norm1 = norm_layer(dim) 138 | self.attn = Attention( 139 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 140 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 141 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 142 | self.norm2 = norm_layer(dim) 143 | mlp_hidden_dim = int(dim * mlp_ratio) 144 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 145 | 146 | 147 | def forward(self, x): 148 | x = x + self.drop_path(self.attn(self.norm1(x))) 149 | x = x + self.drop_path(self.mlp(self.norm2(x))) 150 | return x 151 | 152 | 153 | class PatchEmbed(nn.Module): 154 | """ Image to Patch Embedding 155 | """ 156 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 157 | super().__init__() 158 | img_size = to_2tuple(img_size) 159 | patch_size = to_2tuple(patch_size) 160 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 161 | self.img_size = img_size 162 | self.patch_size = patch_size 163 | self.num_patches = num_patches 164 | 165 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 166 | 167 | def forward(self, x): 168 | B, C, H, W = x.shape 169 | # FIXME look at relaxing size constraints 170 | assert H == self.img_size[0] and W == self.img_size[1], \ 171 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 172 | 173 | x = self.proj(x).flatten(2).transpose(1, 2) 174 | return x 175 | 176 | 177 | class HybridEmbed(nn.Module): 178 | """ CNN Feature Map Embedding 179 | Extract feature map from CNN, flatten, project to embedding dim. 180 | """ 181 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 182 | super().__init__() 183 | assert isinstance(backbone, nn.Module) 184 | img_size = to_2tuple(img_size) 185 | self.img_size = img_size 186 | self.backbone = backbone 187 | if feature_size is None: 188 | with torch.no_grad(): 189 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 190 | # map for all networks, the feature metadata has reliable channel and stride info, but using 191 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
192 | training = backbone.training 193 | if training: 194 | backbone.eval() 195 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 196 | feature_size = o.shape[-2:] 197 | feature_dim = o.shape[1] 198 | backbone.train(training) 199 | else: 200 | feature_size = to_2tuple(feature_size) 201 | feature_dim = self.backbone.feature_info.channels()[-1] 202 | self.num_patches = feature_size[0] * feature_size[1] 203 | self.proj = nn.Linear(feature_dim, embed_dim) 204 | 205 | def forward(self, x): 206 | x = self.backbone(x)[-1] 207 | x = x.flatten(2).transpose(1, 2) 208 | x = self.proj(x) 209 | return x 210 | 211 | 212 | class VisionTransformer(nn.Module): 213 | """ Vision Transformer with support for patch or hybrid CNN input stage 214 | """ 215 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 216 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 217 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pruning_type='unstructure'): 218 | super().__init__() 219 | self.pruning_type = pruning_type 220 | self.num_classes = num_classes 221 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 222 | 223 | if hybrid_backbone is not None: 224 | self.patch_embed = HybridEmbed( 225 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 226 | else: 227 | self.patch_embed = PatchEmbed( 228 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 229 | num_patches = self.patch_embed.num_patches 230 | 231 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 232 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 233 | self.pos_drop = nn.Dropout(p=drop_rate) 234 | 235 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 236 | self.blocks = nn.ModuleList([ 237 | Block( 238 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 239 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 240 | for i in range(depth)]) 241 | self.norm = norm_layer(embed_dim) 242 | 243 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 244 | #self.repr = nn.Linear(embed_dim, representation_size) 245 | #self.repr_act = nn.Tanh() 246 | 247 | # Classifier head 248 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 249 | 250 | trunc_normal_(self.pos_embed, std=.02) 251 | trunc_normal_(self.cls_token, std=.02) 252 | self.apply(self._init_weights) 253 | 254 | def _init_weights(self, m): 255 | if isinstance(m, nn.Linear): 256 | trunc_normal_(m.weight, std=.02) 257 | if isinstance(m, nn.Linear) and m.bias is not None: 258 | nn.init.constant_(m.bias, 0) 259 | elif isinstance(m, nn.LayerNorm): 260 | nn.init.constant_(m.bias, 0) 261 | nn.init.constant_(m.weight, 1.0) 262 | 263 | @torch.jit.ignore 264 | def no_weight_decay(self): 265 | return {'pos_embed', 'cls_token'} 266 | 267 | def get_classifier(self): 268 | return self.head 269 | 270 | def reset_classifier(self, num_classes, global_pool=''): 271 | self.num_classes = num_classes 272 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 273 | 274 | def forward_features(self, x, tau=-1, number=197): 275 | 276 | B = x.shape[0] 277 | x = self.patch_embed(x) 278 | 279 | cls_tokens = self.cls_token.expand(B, 
-1, -1) # stole cls_tokens impl from Phil Wang, thanks 280 | x = torch.cat((cls_tokens, x), dim=1) 281 | x = x + self.pos_embed 282 | x = self.pos_drop(x) 283 | 284 | x = x[:,:number,:] 285 | 286 | for blk in self.blocks: 287 | x = blk(x) 288 | 289 | x = self.norm(x) 290 | return x[:, 0] 291 | 292 | def forward(self, x, tau=-1, number=197): 293 | 294 | x = self.forward_features(x, tau, number) 295 | x = self.head(x) 296 | 297 | return x 298 | -------------------------------------------------------------------------------- /SViTE/vision_gumbel_structure.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | import torch.nn.functional as F 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | # from .helpers import load_pretrained 9 | # from .layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | # from .resnet import resnet26d, resnet50d 12 | # from .registry import register_model 13 | import pdb 14 | import numpy as np 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | 27 | def get_sparsity(module): 28 | 29 | total = module.weight.numel() 30 | zeros = (module.weight==0).sum() 31 | return zeros / total 32 | 33 | def scatter(logits, index, k): 34 | bs = logits.shape[0] 35 | #print('bs = {}'.format(bs)) 36 | 37 | x_index = torch.arange(bs).reshape(-1, 1).expand(bs,k) 38 | x_index = x_index.reshape(-1).tolist() 39 | y_index = index.reshape(-1).tolist() 40 | 41 | output = torch.zeros_like(logits).cuda() 42 | output[x_index, y_index] = 1.0 43 | #print(output.sum(dim=1)) 44 | 45 | return output 46 | 47 | def gumbel_softmax(logits, k, tau=1, hard=False, eps=1e-10, dim=-1): 48 | # type: (torch.Tensor, float, bool, float, int) -> torch.Tensor 49 | 50 | def _gen_gumbels(): 51 | gumbels = -torch.empty_like(logits).cuda().exponential_().log() 52 | if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum(): 53 | # to avoid zero in exp output 54 | gumbels = _gen_gumbels() 55 | return gumbels 56 | 57 | gumbels = _gen_gumbels() # ~Gumbel(0,1) 58 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) 59 | y_soft = gumbels.softmax(dim) 60 | 61 | if hard: 62 | # Straight through. 63 | index = y_soft.topk(k, dim=dim)[1] 64 | y_hard = scatter(logits, index, k) 65 | ret = y_hard - y_soft.detach() + y_soft 66 | else: 67 | # Reparametrization trick. 
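# A small CPU illustration (not part of this file; the .cuda() call used by scatter()
# above is dropped here) of what scatter() produces: a {0, 1} mask of the same shape as
# the logits, with exactly k ones per row placed at the top-k positions.
import torch

logits = torch.tensor([[0.1, 2.0, -1.0, 0.5],
                       [1.5, 0.0,  0.3, 0.2]])
k = 2
index = logits.topk(k, dim=-1)[1]                      # top-k column index per row
mask = torch.zeros_like(logits)
mask[torch.arange(2).repeat_interleave(k), index.reshape(-1)] = 1.0
print(mask.sum(dim=1))                                 # tensor([2., 2.])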
68 | ret = y_soft 69 | 70 | if torch.isnan(ret).sum(): 71 | import ipdb 72 | ipdb.set_trace() 73 | raise OverflowError(f'gumbel softmax output: {ret}') 74 | return ret 75 | 76 | 77 | class Mlp(nn.Module): 78 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 79 | super().__init__() 80 | out_features = out_features or in_features 81 | hidden_features = hidden_features or in_features 82 | self.fc1 = nn.Linear(in_features, hidden_features) 83 | self.act = act_layer() 84 | self.fc2 = nn.Linear(hidden_features, out_features) 85 | self.drop = nn.Dropout(drop) 86 | 87 | def forward(self, x): 88 | 89 | x = self.fc1(x) 90 | x = self.act(x) 91 | x = self.drop(x) 92 | 93 | x = self.fc2(x) 94 | x = self.drop(x) 95 | 96 | return x 97 | 98 | 99 | class Attention(nn.Module): 100 | def __init__(self, dim, sparse_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 101 | super().__init__() 102 | self.num_heads = num_heads 103 | head_dim = sparse_dim // num_heads 104 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 105 | self.scale = qk_scale or head_dim ** -0.5 106 | self.qkv = nn.Linear(dim, sparse_dim * 3, bias=qkv_bias) 107 | self.attn_drop = nn.Dropout(attn_drop) 108 | self.proj = nn.Linear(dim, dim) 109 | self.proj_drop = nn.Dropout(proj_drop) 110 | self.zero_dim = dim - sparse_dim 111 | self.sparse_dim = sparse_dim 112 | 113 | def forward(self, x): 114 | 115 | zeros = torch.zeros_like(x)[:,:,:self.zero_dim] 116 | 117 | B, N, C = x.shape 118 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.sparse_dim // self.num_heads).permute(2, 0, 3, 1, 4) 119 | 120 | q, k, v = qkv[0], qkv[1], qkv[2] # torch.Size([256, 3, 197, 64]) 121 | attn = (q @ k.transpose(-2, -1)) * self.scale # torch.Size([256, 3, 64, 197]) 122 | attn = attn.softmax(dim=-1) # torch.Size([256, 3, 197, 197]) 123 | attn = self.attn_drop(attn) 124 | 125 | x = (attn @ v).transpose(1, 2) 126 | 127 | x = x.reshape(B, N, self.sparse_dim) 128 | x = torch.cat([x, zeros], dim=2) 129 | x = self.proj(x) 130 | 131 | x = self.proj_drop(x) 132 | return x 133 | 134 | 135 | class Block(nn.Module): 136 | 137 | def __init__(self, dim, num_heads, sparse_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 138 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 139 | super().__init__() 140 | self.norm1 = norm_layer(dim) 141 | self.attn = Attention( 142 | dim, sparse_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 143 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 144 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 145 | self.norm2 = norm_layer(dim) 146 | mlp_hidden_dim = int(dim * mlp_ratio) 147 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 148 | 149 | def forward(self, x): 150 | x = x + self.drop_path(self.attn(self.norm1(x))) 151 | x = x + self.drop_path(self.mlp(self.norm2(x))) 152 | return x 153 | 154 | 155 | class PatchEmbed(nn.Module): 156 | """ Image to Patch Embedding 157 | """ 158 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | self.num_patches = num_patches 166 | 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 168 | 169 | def forward(self, x): 170 | B, C, H, W = x.shape 171 | # FIXME look at relaxing size constraints 172 | assert H == self.img_size[0] and W == self.img_size[1], \ 173 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 174 | 175 | x = self.proj(x).flatten(2).transpose(1, 2) 176 | return x 177 | 178 | 179 | class HybridEmbed(nn.Module): 180 | """ CNN Feature Map Embedding 181 | Extract feature map from CNN, flatten, project to embedding dim. 182 | """ 183 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 184 | super().__init__() 185 | assert isinstance(backbone, nn.Module) 186 | img_size = to_2tuple(img_size) 187 | self.img_size = img_size 188 | self.backbone = backbone 189 | if feature_size is None: 190 | with torch.no_grad(): 191 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 192 | # map for all networks, the feature metadata has reliable channel and stride info, but using 193 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
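# A shape-only sketch (hypothetical sizes, not part of this file) of the structured-sparse
# Attention defined earlier in this file: qkv projects dim -> 3 * sparse_dim, attention runs
# in the reduced width, and the result is zero-padded back to dim before self.proj.
import torch
import torch.nn as nn

B, N, dim, sparse_dim, num_heads = 2, 197, 192, 96, 3
head_dim = sparse_dim // num_heads                                        # 32
x = torch.randn(B, N, dim)
qkv = nn.Linear(dim, sparse_dim * 3)(x).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                                          # each (B, heads, N, head_dim)
attn = ((q @ k.transpose(-2, -1)) * head_dim ** -0.5).softmax(dim=-1)     # (B, heads, N, N)
out = (attn @ v).transpose(1, 2).reshape(B, N, sparse_dim)                # (B, N, sparse_dim)
out = torch.cat([out, torch.zeros(B, N, dim - sparse_dim)], dim=2)        # zero-pad back to dim
assert out.shape == (B, N, dim)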
194 | training = backbone.training 195 | if training: 196 | backbone.eval() 197 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 198 | feature_size = o.shape[-2:] 199 | feature_dim = o.shape[1] 200 | backbone.train(training) 201 | else: 202 | feature_size = to_2tuple(feature_size) 203 | feature_dim = self.backbone.feature_info.channels()[-1] 204 | self.num_patches = feature_size[0] * feature_size[1] 205 | self.proj = nn.Linear(feature_dim, embed_dim) 206 | 207 | def forward(self, x): 208 | x = self.backbone(x)[-1] 209 | x = x.flatten(2).transpose(1, 2) 210 | x = self.proj(x) 211 | return x 212 | 213 | 214 | class VisionTransformer(nn.Module): 215 | """ Vision Transformer with support for patch or hybrid CNN input stage 216 | """ 217 | def __init__(self, sparse_dim, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 218 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 219 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pruning_type='unstructure', sparse=1.0): 220 | super().__init__() 221 | self.pruning_type = pruning_type 222 | self.num_classes = num_classes 223 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 224 | 225 | if hybrid_backbone is not None: 226 | self.patch_embed = HybridEmbed( 227 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 228 | else: 229 | self.patch_embed = PatchEmbed( 230 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 231 | num_patches = self.patch_embed.num_patches 232 | 233 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 234 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 235 | self.pos_drop = nn.Dropout(p=drop_rate) 236 | 237 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 238 | self.blocks = nn.ModuleList([ 239 | Block( 240 | dim=embed_dim, num_heads=num_heads, sparse_dim=sparse_dim, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 241 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 242 | for i in range(depth)]) 243 | self.norm = norm_layer(embed_dim) 244 | 245 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 246 | #self.repr = nn.Linear(embed_dim, representation_size) 247 | #self.repr_act = nn.Tanh() 248 | 249 | # Classifier head 250 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 251 | 252 | trunc_normal_(self.pos_embed, std=.02) 253 | trunc_normal_(self.cls_token, std=.02) 254 | self.apply(self._init_weights) 255 | 256 | def _init_weights(self, m): 257 | if isinstance(m, nn.Linear): 258 | trunc_normal_(m.weight, std=.02) 259 | if isinstance(m, nn.Linear) and m.bias is not None: 260 | nn.init.constant_(m.bias, 0) 261 | elif isinstance(m, nn.LayerNorm): 262 | nn.init.constant_(m.bias, 0) 263 | nn.init.constant_(m.weight, 1.0) 264 | 265 | @torch.jit.ignore 266 | def no_weight_decay(self): 267 | return {'pos_embed', 'cls_token'} 268 | 269 | def get_classifier(self): 270 | return self.head 271 | 272 | def reset_classifier(self, num_classes, global_pool=''): 273 | self.num_classes = num_classes 274 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 275 | 276 | def forward_features(self, x, tau=-1, number=197): 277 | 278 | B = x.shape[0] 279 | x = self.patch_embed(x) 280 
| 281 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 282 | x = torch.cat((cls_tokens, x), dim=1) 283 | x = x + self.pos_embed 284 | x = self.pos_drop(x) 285 | 286 | x = x[:,:number,:] 287 | 288 | for blk in self.blocks: 289 | x = blk(x) 290 | 291 | x = self.norm(x) 292 | return x[:, 0] 293 | 294 | def forward(self, x, tau=-1, number=197): 295 | 296 | x = self.forward_features(x, tau, number) 297 | x = self.head(x) 298 | 299 | return x 300 | -------------------------------------------------------------------------------- /SViTE/vision_transformer.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | import torch.nn.functional as F 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | # from .helpers import load_pretrained 9 | # from .layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | # from .resnet import resnet26d, resnet50d 12 | # from .registry import register_model 13 | import pdb 14 | import numpy as np 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | def scatter(logits, index, k): 27 | bs = logits.shape[0] 28 | #print('bs = {}'.format(bs)) 29 | 30 | x_index = torch.arange(bs).reshape(-1, 1).expand(bs,k) 31 | x_index = x_index.reshape(-1).tolist() 32 | y_index = index.reshape(-1).tolist() 33 | 34 | output = torch.zeros_like(logits).cuda() 35 | output[x_index, y_index] = 1.0 36 | #print(output.sum(dim=1)) 37 | 38 | return output 39 | 40 | def gumbel_softmax(logits, k, tau=1, hard=False, eps=1e-10, dim=-1): 41 | # type: (torch.Tensor, float, bool, float, int) -> torch.Tensor 42 | 43 | def _gen_gumbels(): 44 | gumbels = -torch.empty_like(logits).cuda().exponential_().log() 45 | if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum(): 46 | # to avoid zero in exp output 47 | gumbels = _gen_gumbels() 48 | return gumbels 49 | 50 | gumbels = _gen_gumbels() # ~Gumbel(0,1) 51 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) 52 | y_soft = gumbels.softmax(dim) 53 | 54 | if hard: 55 | # Straight through. 56 | index = y_soft.topk(k, dim=dim)[1] 57 | y_hard = scatter(logits, index, k) 58 | ret = y_hard - y_soft.detach() + y_soft 59 | else: 60 | # Reparametrization trick. 
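# A minimal CPU sketch (not part of this file) of the hard top-k trick used by
# gumbel_softmax in these files: the forward value is a {0, 1} mask with k ones per row,
# while gradients flow through the soft relaxation via y_hard - y_soft.detach() + y_soft.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, requires_grad=True)
tau, k = 1.0, 3
g = -torch.empty_like(logits).exponential_().log()              # ~ Gumbel(0, 1)
y_soft = ((F.log_softmax(logits, dim=-1) + g) / tau).softmax(dim=-1)
index = y_soft.topk(k, dim=-1)[1]
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)      # hard top-k mask
mask = y_hard - y_soft.detach() + y_soft                        # hard forward, soft backward
mask.sum().backward()
print(mask.detach().sum(dim=-1), logits.grad is not None)       # tensor([3., 3.]) True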
61 | ret = y_soft 62 | 63 | if torch.isnan(ret).sum(): 64 | import ipdb 65 | ipdb.set_trace() 66 | raise OverflowError(f'gumbel softmax output: {ret}') 67 | return ret 68 | 69 | class Mlp(nn.Module): 70 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 71 | super().__init__() 72 | out_features = out_features or in_features 73 | hidden_features = hidden_features or in_features 74 | self.fc1 = nn.Linear(in_features, hidden_features) 75 | self.act = act_layer() 76 | self.fc2 = nn.Linear(hidden_features, out_features) 77 | self.drop = nn.Dropout(drop) 78 | 79 | def forward(self, x): 80 | x = self.fc1(x) 81 | x = self.act(x) 82 | x = self.drop(x) 83 | x = self.fc2(x) 84 | x = self.drop(x) 85 | return x 86 | 87 | class Attention(nn.Module): 88 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 89 | super().__init__() 90 | self.num_heads = num_heads 91 | head_dim = dim // num_heads 92 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 93 | self.scale = qk_scale or head_dim ** -0.5 94 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 95 | self.attn_drop = nn.Dropout(attn_drop) 96 | self.proj = nn.Linear(dim, dim) 97 | self.proj_drop = nn.Dropout(proj_drop) 98 | 99 | self.atten_values = None 100 | self.grad_scores = None 101 | self.grad_norm = None 102 | 103 | def _score(self, grad): # grad (B, N, num_heads, length) 104 | self.grad_scores = torch.einsum('bnhl,bnhl->bh', grad, self.atten_values).abs().mean(dim=0) 105 | self.grad_norm = grad.norm(dim=(1,3), p=1).mean(dim=0) 106 | 107 | def forward(self, x): 108 | B, N, C = x.shape 109 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 110 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 111 | attn = (q @ k.transpose(-2, -1)) * self.scale 112 | attn = attn.softmax(dim=-1) 113 | attn = self.attn_drop(attn) 114 | 115 | x = (attn @ v).transpose(1, 2) 116 | if self.training: 117 | self.atten_values = x 118 | x.register_hook(self._score) 119 | 120 | indicator_list = compute_indicator(x) 121 | x = x.reshape(B, N, C) 122 | x = self.proj(x) 123 | x = self.proj_drop(x) 124 | return x, indicator_list 125 | 126 | def compute_indicator(input_tensor, how='l1'): 127 | 128 | input_tensor = input_tensor.transpose(0, 2) 129 | head_num = input_tensor.shape[0] 130 | indicator_list = [] 131 | for i in range(head_num): 132 | norm = torch.norm(input_tensor[i], p=1).detach().cpu().item() 133 | indicator_list.append(norm) 134 | return indicator_list 135 | 136 | class Block(nn.Module): 137 | 138 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 139 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 140 | super().__init__() 141 | self.norm1 = norm_layer(dim) 142 | self.attn = Attention( 143 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 144 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 145 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 146 | self.norm2 = norm_layer(dim) 147 | mlp_hidden_dim = int(dim * mlp_ratio) 148 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 149 | 150 | def forward(self, x): 151 | 152 | attention, l1 = self.attn(self.norm1(x)) 153 | x = x + self.drop_path(attention) 154 | x = x + self.drop_path(self.mlp(self.norm2(x))) 155 | return x, l1 156 | # def forward(self, x): 157 | # x = x + self.drop_path(self.attn(self.norm1(x))) 158 | # x = x + self.drop_path(self.mlp(self.norm2(x))) 159 | # return x 160 | 161 | class PatchEmbed(nn.Module): 162 | """ Image to Patch Embedding 163 | """ 164 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 165 | super().__init__() 166 | img_size = to_2tuple(img_size) 167 | patch_size = to_2tuple(patch_size) 168 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 169 | self.img_size = img_size 170 | self.patch_size = patch_size 171 | self.num_patches = num_patches 172 | 173 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 174 | 175 | def forward(self, x): 176 | B, C, H, W = x.shape 177 | # FIXME look at relaxing size constraints 178 | assert H == self.img_size[0] and W == self.img_size[1], \ 179 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 180 | x = self.proj(x).flatten(2).transpose(1, 2) 181 | return x 182 | 183 | class HybridEmbed(nn.Module): 184 | """ CNN Feature Map Embedding 185 | Extract feature map from CNN, flatten, project to embedding dim. 186 | """ 187 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 188 | super().__init__() 189 | assert isinstance(backbone, nn.Module) 190 | img_size = to_2tuple(img_size) 191 | self.img_size = img_size 192 | self.backbone = backbone 193 | if feature_size is None: 194 | with torch.no_grad(): 195 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 196 | # map for all networks, the feature metadata has reliable channel and stride info, but using 197 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
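# A small sketch (hypothetical sizes, not part of this file) of the per-head L1 indicator
# computed by compute_indicator above: the attention output (B, N, heads, head_dim) is
# sliced per head and its L1 norm is taken as a saliency score for structured head pruning;
# the _score hook adds a gradient-based score of the same kind during training.
import torch

B, N, heads, head_dim = 2, 197, 3, 64
attn_out = torch.randn(B, N, heads, head_dim)          # shape after (attn @ v).transpose(1, 2)
per_head = attn_out.transpose(0, 2)                    # (heads, N, B, head_dim)
scores = [per_head[h].norm(p=1).item() for h in range(heads)]
print(len(scores))                                     # 3, one L1 score per head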
198 | training = backbone.training 199 | if training: 200 | backbone.eval() 201 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 202 | feature_size = o.shape[-2:] 203 | feature_dim = o.shape[1] 204 | backbone.train(training) 205 | else: 206 | feature_size = to_2tuple(feature_size) 207 | feature_dim = self.backbone.feature_info.channels()[-1] 208 | self.num_patches = feature_size[0] * feature_size[1] 209 | self.proj = nn.Linear(feature_dim, embed_dim) 210 | 211 | def forward(self, x): 212 | x = self.backbone(x)[-1] 213 | x = x.flatten(2).transpose(1, 2) 214 | x = self.proj(x) 215 | return x 216 | 217 | class VisionTransformer(nn.Module): 218 | """ Vision Transformer with support for patch or hybrid CNN input stage 219 | """ 220 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 221 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 222 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pruning_type='unstructure'): 223 | super().__init__() 224 | self.pruning_type = pruning_type 225 | self.num_classes = num_classes 226 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 227 | 228 | if hybrid_backbone is not None: 229 | self.patch_embed = HybridEmbed( 230 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 231 | else: 232 | self.patch_embed = PatchEmbed( 233 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 234 | num_patches = self.patch_embed.num_patches 235 | 236 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 237 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 238 | self.pos_drop = nn.Dropout(p=drop_rate) 239 | self.gumbel = nn.Linear(embed_dim, 1) 240 | 241 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 242 | self.blocks = nn.ModuleList([ 243 | Block( 244 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 245 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 246 | for i in range(depth)]) 247 | self.norm = norm_layer(embed_dim) 248 | 249 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 250 | #self.repr = nn.Linear(embed_dim, representation_size) 251 | #self.repr_act = nn.Tanh() 252 | 253 | # Classifier head 254 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 255 | 256 | trunc_normal_(self.pos_embed, std=.02) 257 | trunc_normal_(self.cls_token, std=.02) 258 | self.apply(self._init_weights) 259 | 260 | def _init_weights(self, m): 261 | if isinstance(m, nn.Linear): 262 | trunc_normal_(m.weight, std=.02) 263 | if isinstance(m, nn.Linear) and m.bias is not None: 264 | nn.init.constant_(m.bias, 0) 265 | elif isinstance(m, nn.LayerNorm): 266 | nn.init.constant_(m.bias, 0) 267 | nn.init.constant_(m.weight, 1.0) 268 | 269 | @torch.jit.ignore 270 | def no_weight_decay(self): 271 | return {'pos_embed', 'cls_token'} 272 | 273 | def get_classifier(self): 274 | return self.head 275 | 276 | def reset_classifier(self, num_classes, global_pool=''): 277 | self.num_classes = num_classes 278 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 279 | 280 | def forward_features(self, x, tau=-1, number=197): 281 | 282 | l1_list = [] 283 | B = x.shape[0] 284 | x = 
self.patch_embed(x) 285 | 286 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 287 | x = torch.cat((cls_tokens, x), dim=1) 288 | x = x + self.pos_embed 289 | x = self.pos_drop(x) # [Batch, token, dim] 290 | 291 | if tau > 0: 292 | emb_dim = x.shape[2] 293 | token_number = x.shape[1] 294 | token_scores = self.gumbel(x) 295 | token_scores = token_scores.reshape(B, -1) 296 | token_mask = gumbel_softmax(F.log_softmax(token_scores, dim=-1), k=number, tau=tau, hard=True) 297 | token_mask[:,0] = 1. 298 | token_mask = token_mask.expand(emb_dim,-1,-1).permute(1,2,0) 299 | 300 | x = x * token_mask 301 | 302 | for blk in self.blocks: 303 | x, l1 = blk(x) 304 | l1_list.append(l1) 305 | 306 | x = self.norm(x) 307 | return x[:, 0], l1_list 308 | 309 | def forward(self, x, tau=-1, number=197): 310 | 311 | x, l1_indicator = self.forward_features(x, tau, number) 312 | x = self.head(x) 313 | if self.pruning_type == 'structure': 314 | return x, l1_indicator 315 | else: 316 | return x 317 | -------------------------------------------------------------------------------- /SViTE/vision_transformer_data.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | import torch.nn.functional as F 7 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | # from .helpers import load_pretrained 9 | # from .layers import DropPath, to_2tuple, trunc_normal_ 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | # from .resnet import resnet26d, resnet50d 12 | # from .registry import register_model 13 | import pdb 14 | import numpy as np 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | def scatter(logits, index, k): 27 | bs = logits.shape[0] 28 | #print('bs = {}'.format(bs)) 29 | 30 | x_index = torch.arange(bs).reshape(-1, 1).expand(bs,k) 31 | x_index = x_index.reshape(-1).tolist() 32 | y_index = index.reshape(-1).tolist() 33 | 34 | output = torch.zeros_like(logits).cuda() 35 | output[x_index, y_index] = 1.0 36 | #print(output.sum(dim=1)) 37 | 38 | return output 39 | 40 | def gumbel_softmax(logits, k, tau=1, hard=False, eps=1e-10, dim=-1): 41 | # type: (torch.Tensor, float, bool, float, int) -> torch.Tensor 42 | 43 | def _gen_gumbels(): 44 | gumbels = -torch.empty_like(logits).cuda().exponential_().log() 45 | if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum(): 46 | # to avoid zero in exp output 47 | gumbels = _gen_gumbels() 48 | return gumbels 49 | 50 | gumbels = _gen_gumbels() # ~Gumbel(0,1) 51 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) 52 | y_soft = gumbels.softmax(dim) 53 | 54 | if hard: 55 | # Straight through. 56 | index = y_soft.topk(k, dim=dim)[1] 57 | y_hard = scatter(logits, index, k) 58 | ret = y_hard - y_soft.detach() + y_soft 59 | else: 60 | # Reparametrization trick. 
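# A minimal sketch (hypothetical sizes, not part of this file) of how forward_features
# above broadcasts the Gumbel token mask: a (B, tokens) mask is expanded to
# (B, tokens, emb_dim) and multiplied into x, zeroing whole tokens while the [CLS]
# column is forced to stay on.
import torch

B, tokens, emb_dim = 2, 197, 192
x = torch.randn(B, tokens, emb_dim)
token_mask = (torch.rand(B, tokens) > 0.5).float()     # stand-in for the Gumbel top-k mask
token_mask[:, 0] = 1.0                                 # always keep the [CLS] token
mask3d = token_mask.expand(emb_dim, -1, -1).permute(1, 2, 0)   # (B, tokens, emb_dim)
x = x * mask3d
assert torch.all(x[token_mask == 0] == 0)              # dropped tokens are exactly zero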
61 | ret = y_soft 62 | 63 | if torch.isnan(ret).sum(): 64 | import ipdb 65 | ipdb.set_trace() 66 | raise OverflowError(f'gumbel softmax output: {ret}') 67 | return ret 68 | 69 | class Mlp(nn.Module): 70 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 71 | super().__init__() 72 | out_features = out_features or in_features 73 | hidden_features = hidden_features or in_features 74 | self.fc1 = nn.Linear(in_features, hidden_features) 75 | self.act = act_layer() 76 | self.fc2 = nn.Linear(hidden_features, out_features) 77 | self.drop = nn.Dropout(drop) 78 | 79 | def forward(self, x): 80 | x = self.fc1(x) 81 | x = self.act(x) 82 | x = self.drop(x) 83 | x = self.fc2(x) 84 | x = self.drop(x) 85 | return x 86 | 87 | class Attention(nn.Module): 88 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 89 | super().__init__() 90 | self.num_heads = num_heads 91 | head_dim = dim // num_heads 92 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 93 | self.scale = qk_scale or head_dim ** -0.5 94 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 95 | self.gumbel = nn.Linear(dim,1) 96 | self.attn_drop = nn.Dropout(attn_drop) 97 | self.proj = nn.Linear(dim, dim) 98 | self.proj_drop = nn.Dropout(proj_drop) 99 | self.score = False 100 | 101 | self.atten_values = None 102 | self.grad_scores = None 103 | self.grad_norm = None 104 | 105 | def _score(self, grad): # grad (B, N, num_heads, length) 106 | self.grad_scores = torch.einsum('bnhl,bnhl->bh', grad, self.atten_values).abs().mean(dim=0) 107 | self.grad_norm = grad.norm(dim=(1,3), p=1).mean(dim=0) 108 | 109 | def forward(self, x, tau=-1, number=197): 110 | B, N, C = x.shape 111 | 112 | if tau > 0: 113 | token_scores = self.gumbel(x) 114 | token_scores = token_scores.reshape(B, -1) 115 | token_mask = gumbel_softmax(F.log_softmax(token_scores, dim=-1), k=number, tau=tau, hard=True) 116 | token_mask[:,0] = 1. 117 | token_mask = token_mask.expand(C,-1,-1).permute(1,2,0) 118 | x = x * token_mask 119 | 120 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 121 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 122 | attn = (q @ k.transpose(-2, -1)) * self.scale 123 | attn = attn.softmax(dim=-1) 124 | attn = self.attn_drop(attn) 125 | 126 | x = (attn @ v).transpose(1, 2) 127 | if self.training and self.score: 128 | self.atten_values = x 129 | x.register_hook(self._score) 130 | 131 | x = x.reshape(B, N, C) 132 | x = self.proj(x) 133 | x = self.proj_drop(x) 134 | return x 135 | 136 | class Block(nn.Module): 137 | 138 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 139 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 140 | super().__init__() 141 | self.norm1 = norm_layer(dim) 142 | self.attn = Attention( 143 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 144 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 145 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 146 | self.norm2 = norm_layer(dim) 147 | mlp_hidden_dim = int(dim * mlp_ratio) 148 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 149 | 150 | def forward(self, x, tau, number): 151 | x = x + self.drop_path(self.attn(self.norm1(x), tau=tau, number=number)) 152 | x = x + self.drop_path(self.mlp(self.norm2(x))) 153 | return x 154 | 155 | class PatchEmbed(nn.Module): 156 | """ Image to Patch Embedding 157 | """ 158 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | self.num_patches = num_patches 166 | 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 168 | 169 | def forward(self, x): 170 | B, C, H, W = x.shape 171 | # FIXME look at relaxing size constraints 172 | assert H == self.img_size[0] and W == self.img_size[1], \ 173 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 174 | x = self.proj(x).flatten(2).transpose(1, 2) 175 | return x 176 | 177 | class HybridEmbed(nn.Module): 178 | """ CNN Feature Map Embedding 179 | Extract feature map from CNN, flatten, project to embedding dim. 180 | """ 181 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 182 | super().__init__() 183 | assert isinstance(backbone, nn.Module) 184 | img_size = to_2tuple(img_size) 185 | self.img_size = img_size 186 | self.backbone = backbone 187 | if feature_size is None: 188 | with torch.no_grad(): 189 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 190 | # map for all networks, the feature metadata has reliable channel and stride info, but using 191 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
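# A small sketch (not part of this file; torchvision's resnet18 is used only as a
# hypothetical backbone) of the dummy-forward trick the comment above describes and the
# code that follows implements: push one zero image through the backbone in eval() mode
# to read off the spatial size and channel count of its last feature map, then restore
# the original training mode.
import torch
from torchvision.models import resnet18

backbone = torch.nn.Sequential(*list(resnet18().children())[:-2])   # drop avgpool + fc
was_training = backbone.training
backbone.eval()
with torch.no_grad():
    o = backbone(torch.zeros(1, 3, 224, 224))
backbone.train(was_training)
print(o.shape[-2:], o.shape[1])                                      # torch.Size([7, 7]) 512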
192 | training = backbone.training 193 | if training: 194 | backbone.eval() 195 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 196 | feature_size = o.shape[-2:] 197 | feature_dim = o.shape[1] 198 | backbone.train(training) 199 | else: 200 | feature_size = to_2tuple(feature_size) 201 | feature_dim = self.backbone.feature_info.channels()[-1] 202 | self.num_patches = feature_size[0] * feature_size[1] 203 | self.proj = nn.Linear(feature_dim, embed_dim) 204 | 205 | def forward(self, x): 206 | x = self.backbone(x)[-1] 207 | x = x.flatten(2).transpose(1, 2) 208 | x = self.proj(x) 209 | return x 210 | 211 | class VisionTransformer(nn.Module): 212 | """ Vision Transformer with support for patch or hybrid CNN input stage 213 | """ 214 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 215 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 216 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pruning_type='unstructure'): 217 | super().__init__() 218 | self.pruning_type = pruning_type 219 | self.num_classes = num_classes 220 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 221 | 222 | if hybrid_backbone is not None: 223 | self.patch_embed = HybridEmbed( 224 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 225 | else: 226 | self.patch_embed = PatchEmbed( 227 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 228 | num_patches = self.patch_embed.num_patches 229 | 230 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 231 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 232 | self.pos_drop = nn.Dropout(p=drop_rate) 233 | self.gumbel = nn.Linear(embed_dim, 1) 234 | 235 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 236 | self.blocks = nn.ModuleList([ 237 | Block( 238 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 239 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 240 | for i in range(depth)]) 241 | self.norm = norm_layer(embed_dim) 242 | 243 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 244 | #self.repr = nn.Linear(embed_dim, representation_size) 245 | #self.repr_act = nn.Tanh() 246 | 247 | # Classifier head 248 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 249 | 250 | trunc_normal_(self.pos_embed, std=.02) 251 | trunc_normal_(self.cls_token, std=.02) 252 | self.apply(self._init_weights) 253 | 254 | def _init_weights(self, m): 255 | if isinstance(m, nn.Linear): 256 | trunc_normal_(m.weight, std=.02) 257 | if isinstance(m, nn.Linear) and m.bias is not None: 258 | nn.init.constant_(m.bias, 0) 259 | elif isinstance(m, nn.LayerNorm): 260 | nn.init.constant_(m.bias, 0) 261 | nn.init.constant_(m.weight, 1.0) 262 | 263 | @torch.jit.ignore 264 | def no_weight_decay(self): 265 | return {'pos_embed', 'cls_token'} 266 | 267 | def get_classifier(self): 268 | return self.head 269 | 270 | def reset_classifier(self, num_classes, global_pool=''): 271 | self.num_classes = num_classes 272 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 273 | 274 | def forward_features(self, x, tau, number): 275 | 276 | B = x.shape[0] 277 | x = self.patch_embed(x) 278 | 279 | 
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 280 | x = torch.cat((cls_tokens, x), dim=1) 281 | x = x + self.pos_embed 282 | x = self.pos_drop(x) # [Batch, token, dim] 283 | 284 | if tau > 0: 285 | emb_dim = x.shape[2] 286 | token_scores = self.gumbel(x) 287 | token_scores = token_scores.reshape(B, -1) 288 | token_mask = gumbel_softmax(F.log_softmax(token_scores, dim=-1), k=number, tau=tau, hard=True) 289 | token_mask[:,0] = 1. 290 | token_mask = token_mask.expand(emb_dim,-1,-1).permute(1,2,0) 291 | 292 | x = x * token_mask 293 | 294 | for blk in self.blocks: 295 | x = blk(x, tau, number) 296 | 297 | x = self.norm(x) 298 | 299 | return x[:, 0] 300 | 301 | def forward(self, x, tau=-1, number=197): 302 | 303 | x = self.forward_features(x, tau, number) 304 | x = self.head(x) 305 | 306 | return x 307 | --------------------------------------------------------------------------------
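A minimal usage sketch, not part of the repository, assuming it runs in the same module as the VisionTransformer defined in vision_transformer_data.py above and using a hypothetical DeiT-tiny-like configuration. With tau <= 0 the model performs a dense forward over all 197 tokens (196 patches plus [CLS] at 224x224 with patch size 16); setting tau > 0 together with a smaller number enables the Gumbel token selection, which in the repo code assumes CUDA because scatter() and gumbel_softmax() call .cuda().

import torch
import torch.nn as nn
from functools import partial

# VisionTransformer refers to the class defined in vision_transformer_data.py above.
model = VisionTransformer(img_size=224, patch_size=16, embed_dim=192, depth=12,
                          num_heads=3, mlp_ratio=4., qkv_bias=True,
                          norm_layer=partial(nn.LayerNorm, eps=1e-6))
imgs = torch.randn(2, 3, 224, 224)
logits = model(imgs, tau=-1, number=197)   # tau <= 0: dense baseline over all tokens
print(logits.shape)                        # torch.Size([2, 1000])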