├── .gitignore ├── LICENSE ├── README.md ├── SPViT_DeiT ├── .gitignore ├── README.md ├── config │ ├── spvit_deit_bs_l006_t100_ft.json │ ├── spvit_deit_bs_l006_t100_search.json │ ├── spvit_deit_bs_l008_t60_ft.json │ ├── spvit_deit_bs_l008_t60_ft_dist.json │ ├── spvit_deit_bs_l008_t60_ft_param_opt.json │ ├── spvit_deit_bs_l008_t60_search.json │ ├── spvit_deit_sm_l30_t32_ft.json │ ├── spvit_deit_sm_l30_t32_ft_dist.json │ ├── spvit_deit_sm_l30_t32_search.json │ ├── spvit_deit_ti_l200_t10_ft.json │ ├── spvit_deit_ti_l200_t10_ft_dist.json │ └── spvit_deit_ti_l200_t10_search.json ├── datasets.py ├── engine.py ├── ffn_indicators │ ├── .DS_Store │ ├── spvit_deit_bs_l006_t100_search_15epoch.pth │ ├── spvit_deit_bs_l008_t60_search_10epoch.pth │ ├── spvit_deit_sm_l30_t32_search_10epoch.pth │ └── spvit_deit_ti_l200_t10_search_10epoch.pth ├── hubconf.py ├── logger.py ├── losses.py ├── main.py ├── main_pruning.py ├── models.py ├── models_pruning.py ├── params.py ├── post_training_optimize_checkpoint.py ├── requirements.txt ├── samplers.py ├── tox.ini └── utils.py └── SPViT_Swin ├── .gitignore ├── README.md ├── config.py ├── configs ├── spvit_swin_bs_l01_t100_ft.yaml ├── spvit_swin_bs_l01_t100_search.yaml ├── spvit_swin_sm_l04_t55_ft.yaml ├── spvit_swin_sm_l04_t55_ft_dist.yaml ├── spvit_swin_sm_l04_t55_search.yaml ├── spvit_swin_tn_l28_t32_ft.yaml ├── spvit_swin_tn_l28_t32_ft_dist.yaml └── spvit_swin_tn_l28_t32_search.yaml ├── data ├── __init__.py ├── build.py ├── cached_image_folder.py ├── samplers.py └── zipreader.py ├── dev ├── README.md ├── linter.sh ├── packaging │ ├── README.md │ ├── build_all_wheels.sh │ ├── build_wheel.sh │ ├── gen_install_table.py │ ├── gen_wheel_index.sh │ └── pkg_helpers.bash ├── parse_results.sh ├── run_inference_tests.sh └── run_instant_tests.sh ├── ffn_indicators ├── .DS_Store ├── spvit_swin_bs_l01_t100_search_20epoch.pth ├── spvit_swin_sm_l04_t55_search_14epoch.pth └── spvit_swin_t_l28_t32_search_12epoch.pth ├── logger.py ├── lr_scheduler.py ├── 
main.py ├── main_pruning.py ├── models ├── __init__.py ├── build.py ├── spvit_swin.py └── utils.py ├── optimizer.py ├── post_training_optimize_checkpoint.py ├── requirements.txt ├── setup.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

[TPAMI 2024] Pruning Self-attentions into Convolutional Layers in Single Path

2 | 3 | **This is the official repository for our paper:** [Pruning Self-attentions into Convolutional Layers in Single Path](https://arxiv.org/abs/2111.11802) by [Haoyu He](https://charles-haoyuhe.github.io/), [Jianfei Cai](https://jianfei-cai.github.io/), [Jing liu](https://sites.google.com/view/jing-liu/%E9%A6%96%E9%A1%B5), [Zizheng Pan](https://zizhengpan.github.io/), [Jing Zhang](https://scholar.google.com/citations?user=9jH5v74AAAAJ&hl=en), [Dacheng Tao](https://www.sydney.edu.au/engineering/about/our-people/academic-staff/dacheng-tao.html) and [Bohan Zhuang](https://bohanzhuang.github.io/). 4 | 5 | *** 6 | 7 | >

🚀 News

8 | > 9 | >[2023-12-29]: Accepted by TPAMI! 10 | > 11 | >[2023-06-09]: Update distillation configurations and pre-trained checkpoints. 12 | > 13 | >[2021-12-04]: Release pre-trained models. 14 | > 15 | >[2021-11-25]: Release code. 16 | 17 | *** 18 | 19 | ### Introduction: 20 | 21 | To reduce the massive computational resource consumption for ViTs and add convolutional inductive bias, **our SPViT prunes pre-trained ViT models into accurate and compact hybrid models by pruning self-attentions into convolutional layers**. Thanks to the proposed weight-sharing scheme between self-attention and convolutional layers that cast the search problem as finding which subset of parameters to use, our **SPViT has significantly reduced search cost**. 22 | 23 | *** 24 | 25 | ### Experimental results: 26 | 27 | We provide experimental results and pre-trained models for SPViT: 28 | 29 | | Name | Acc@1 | Acc@5 | # parameters | FLOPs | Model | 30 | | :------------ | :---: | :---: | ------------ | ----- | ------------------------------------------------------------ | 31 | | SPViT-DeiT-Ti | 70.7 | 90.3 | 4.9M | 1.0G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_ti_l200_t10.pth) | 32 | | SPViT-DeiT-Ti* | 73.2 | 91.4 | 4.9M | 1.0G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_ti_l200_t10_dist.pth) | 33 | | SPViT-DeiT-S | 78.3 | 94.3 | 16.4M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_sm_l30_t32.pth) | 34 | | SPViT-DeiT-S* | 80.3 | 95.1 | 16.4M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_sm_l30_t32_dist.pth) | 35 | | SPViT-DeiT-B | 81.5 | 95.7 | 46.2M | 8.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_bs_l008_t60.pth) | 36 | | SPViT-DeiT-B* | 82.4 | 96.1 | 46.2M | 8.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_bs_l008_t60_dist.pth) | 37 | 38 | | Name | Acc@1 | Acc@5 | # parameters | FLOPs | 
Model | 39 | | :------------ | :---: | :---: | ------------ | ----- | ------------------------------------------------------------ | 40 | | SPViT-Swin-Ti | 80.1 | 94.9 | 26.3M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_t_l28_t32.pth) | 41 | | SPViT-Swin-Ti* | 81.0 | 95.3 | 26.3M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_t_l28_t32_dist.pth) | 42 | | SPViT-Swin-S | 82.4 | 96.0 | 39.2M | 6.1G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_sm_l04_t55.pth) | 43 | | SPViT-Swin-S* | 83.0 | 96.4 | 39.2M | 6.1G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_sm_l04_t55_dist.pth) | 44 | 45 | * indicates knowledge distillation. 46 | ### Getting started: 47 | 48 | In this repository, we provide code for pruning two representative ViT models. 49 | 50 | - SPViT-DeiT that prunes [DeiT](https://github.com/facebookresearch/deit). Please see [SPViT_DeiT/README.md](SPViT_DeiT/README.md ) for details. 51 | - SPViT-Swin that prunes [Swin](https://github.com/microsoft/Swin-Transformer). Please see [SPViT_Swin/README.md](SPViT_Swin/README.md) for details. 
If you find our paper useful, please consider citing:
33 | │ ├── class2 34 | │ │ ├── img3.jpeg 35 | │ │ └── ... 36 | │ └── ... 37 | └── val 38 | ├── class1 39 | │ ├── img4.jpeg 40 | │ ├── img5.jpeg 41 | │ └── ... 42 | ├── class2 43 | │ ├── img6.jpeg 44 | │ └── ... 45 | └── ... 46 | ``` 47 | 48 | #### Download pretrained models 49 | 50 | - We start searching and fine-tuneing both from the pre-trained models. 51 | 52 | - Since we provide training scripts for three DeiT models: DeiT-Ti, DeiT-S and DeiT-B, please download the corresponding three pre-trained models from the [DeiT](https://github.com/facebookresearch/deit) repository as well. 53 | 54 | - Next, move the downloaded pre-trained models into the following file structure: 55 | 56 | ```bash 57 | $ tree model 58 | ├── deit_base_patch16_224-b5f2ef4d.pth 59 | ├── deit_small_patch16_224-cd65a155.pth 60 | ├── deit_tiny_patch16_224-a1311bcf.pth 61 | ``` 62 | 63 | - Note that do not change the filenames for the pre-trained models as we hard-coded these filenames when tailoring and loading the pre-trained models. Feel free to modify the hard-coded parts when pruning from other pre-trained models. 
64 | 65 | #### Searching 66 | 67 | To search architectures with SPViT-DeiT-Ti, run: 68 | 69 | ```bash 70 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_search.json 71 | ``` 72 | 73 | To search architectures with SPViT-DeiT-S, run: 74 | 75 | ```bash 76 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_search.json 77 | ``` 78 | 79 | To search architectures with SPViT-DeiT-B, run: 80 | 81 | ```bash 82 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_search.json 83 | ``` 84 | 85 | #### Fine-tuning 86 | 87 | You can start fine-tuning from either your own searched architectures or from our provided architectures by modifying and assigning the MSA indicators in `assigned_indicators` and the FFN indicators in `searching_model`. 88 | 89 | To fine-tune the architectures searched by SPViT-DeiT-Ti, run: 90 | 91 | ```bash 92 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_ft.json 93 | ``` 94 | 95 | To fine-tune the architectures with SPViT-DeiT-S, run: 96 | 97 | ```bash 98 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_ft.json 99 | ``` 100 | 101 | To fine-tune the architectures with SPViT-DeiT-B, run: 102 | 103 | ```bash 104 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_ft.json 105 | ``` 106 | 107 | #### Evaluation 108 | 109 | We provide several examples for evaluating pre-trained SPViT models. 
110 | 111 | To evaluate SPViT-DeiT-Ti pre-trained models, run: 112 | 113 | ```bash 114 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_ft.json --resume [PRE-TRAINED MODEL PATH] --eval 115 | ``` 116 | 117 | To evaluate SPViT-DeiT-S pre-trained models, run: 118 | 119 | ```bash 120 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_ft.json --resume [PRE-TRAINED MODEL PATH] --eval 121 | ``` 122 | 123 | To evaluate SPViT-DeiT-B pre-trained models, run: 124 | 125 | ```bash 126 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_ft.json --resume [PRE-TRAINED MODEL PATH] --eval 127 | ``` 128 | 129 | After fine-tuning, you can optimize your checkpoint to a smaller size with the following code: 130 | ```bash 131 | python post_training_optimize_checkpoint.py YOUR_CHECKPOINT_PATH 132 | ``` 133 | The optimized checkpoint can be evaluated by replacing `UnifiedAttention` with `UnifiedAttentionParamOpt` and we have provided an example in `SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft_param_opt.json`. 134 | 135 | #### TODO: 136 | 137 | ``` 138 | - [x] Release code. 139 | - [x] Release pre-trained models. 
140 | ``` 141 | 142 | -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_bs_l006_t100_ft.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_base_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_bs_l006_t100_ft", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 0.06, 13 | "theta": 1.5, 14 | "target_flops": 10.0, 15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_bs_l006_t100_search_15epoch.pth", 17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "epochs": 130, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_bs_l006_t100_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_base_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_bs_l006_t100_search", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 0.06, 13 | "theta": 1.5, 14 | "target_flops": 10.0, 15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth", 16 | "searching_model": "", 17 | "assigned_indicators": [], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "min_lr": 1e-4, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- 
/SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_base_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_bs_l008_t60_ft", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 0.08, 13 | "theta": 1.5, 14 | "target_flops": 6.0, 15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth", 17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "epochs": 130, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft_dist.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_base_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_bs_l008_t60_ft_dist", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 0.08, 13 | "theta": 1.5, 14 | "target_flops": 6.0, 15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth", 17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 1e-3, 19 | 
"lr": 5e-5, 20 | "epochs": 200, 21 | "warmup_epochs": 0, 22 | "teacher_model": "regnety_160", 23 | "teacher_path": "https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth", 24 | "distillation_type": "hard" 25 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft_param_opt.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_base_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_bs_l008_t60_ft", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttentionParamOpt", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 0.08, 13 | "theta": 1.5, 14 | "target_flops": 6.0, 15 | "resume": "model/spvit_deit_bs_l008_t60_dist_optimized.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth", 17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "epochs": 130, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_bs_l008_t60_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_base_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_bs_l006_t100_search", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 0.08, 13 | "theta": 1.5, 14 | "target_flops": 6.0, 15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth", 16 | "searching_model": 
"", 17 | "assigned_indicators": [], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "min_lr": 1e-4, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_sm_l30_t32_ft.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_small_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_sm_l30_t32_ft", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 3.0, 13 | "theta": 1.5, 14 | "target_flops": 3.2, 15 | "resume": "model/deit_small_patch16_224-cd65a155.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth", 17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 5e-4, 19 | "lr": 5e-5, 20 | "epochs": 130, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_sm_l30_t32_ft_dist.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_small_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_sm_l30_t32_ft_dist", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 3.0, 13 | "theta": 1.5, 14 | "target_flops": 3.2, 15 | "resume": "model/deit_small_patch16_224-cd65a155.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth", 17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 1.0, 
1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 5e-4, 19 | "lr": 5e-5, 20 | "epochs": 200, 21 | "warmup_epochs": 0, 22 | "teacher_model": "regnety_160", 23 | "teacher_path": "https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth", 24 | "distillation_type": "hard" 25 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_sm_l30_t32_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_small_patch16_224", 3 | "batch_size": 128, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_sm_l30_t32_search", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 3.0, 13 | "theta": 1.5, 14 | "target_flops": 3.2, 15 | "resume": "model/deit_small_patch16_224-cd65a155.pth", 16 | "searching_model": "", 17 | "assigned_indicators": [], 18 | "arc_lr": 5e-4, 19 | "lr": 5e-5, 20 | "min_lr": 1e-4, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_ti_l200_t10_ft.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_tiny_patch16_224", 3 | "batch_size": 256, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_ti_l200_t10_ft", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 20.0, 13 | "theta": 1.5, 14 | "target_flops": 1.0, 15 | "resume": "model/deit_tiny_patch16_224-a1311bcf.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth", 17 | 
"assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "epochs": 130, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_ti_l200_t10_ft_dist.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_tiny_patch16_224", 3 | "batch_size": 256, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_ti_l200_t10_ft_dist", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | "num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 20.0, 13 | "theta": 1.5, 14 | "target_flops": 1.0, 15 | "resume": "model/deit_tiny_patch16_224-a1311bcf.pth", 16 | "searching_model": "ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth", 17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "epochs": 200, 21 | "warmup_epochs": 0, 22 | "teacher_model": "regnety_160", 23 | "teacher_path": "https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth", 24 | "distillation_type": "hard" 25 | } -------------------------------------------------------------------------------- /SPViT_DeiT/config/spvit_deit_ti_l200_t10_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "spvit_deit_tiny_patch16_224", 3 | "batch_size": 256, 4 | "data_path": "data/imagenet", 5 | "data_set": "IMNET", 6 | "exp_name": "spvit_deit_ti_l200_t10_search", 7 | "input_size": 224, 8 | "patch_size": 16, 9 | 
"num_workers": 10, 10 | "att_layer": "UnifiedAttention", 11 | "ffn_layer": "UnifiedMlp", 12 | "loss_lambda": 20.0, 13 | "theta": 1.5, 14 | "target_flops": 1.0, 15 | "resume": "model/deit_tiny_patch16_224-a1311bcf.pth", 16 | "searching_model": "", 17 | "assigned_indicators": [], 18 | "arc_lr": 1e-3, 19 | "lr": 5e-5, 20 | "min_lr": 1e-4, 21 | "warmup_epochs": 0 22 | } -------------------------------------------------------------------------------- /SPViT_DeiT/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He 4 | 5 | import os 6 | import json 7 | 8 | from torchvision import datasets, transforms 9 | from torchvision.datasets.folder import ImageFolder, default_loader 10 | 11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from timm.data import create_transform 13 | 14 | 15 | class INatDataset(ImageFolder): 16 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 17 | category='name', loader=default_loader): 18 | self.transform = transform 19 | self.loader = loader 20 | self.target_transform = target_transform 21 | self.year = year 22 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 23 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 24 | with open(path_json) as json_file: 25 | data = json.load(json_file) 26 | 27 | with open(os.path.join(root, 'categories.json')) as json_file: 28 | data_catg = json.load(json_file) 29 | 30 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 31 | 32 | with open(path_json_for_targeter) as json_file: 33 | data_for_targeter = json.load(json_file) 34 | 35 | targeter = {} 36 | indexer = 0 37 | for elem in data_for_targeter['annotations']: 38 | king = [] 39 | 
king.append(data_catg[int(elem['category_id'])][category]) 40 | if king[0] not in targeter.keys(): 41 | targeter[king[0]] = indexer 42 | indexer += 1 43 | self.nb_classes = len(targeter) 44 | 45 | self.samples = [] 46 | for elem in data['images']: 47 | cut = elem['file_name'].split('/') 48 | target_current = int(cut[2]) 49 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 50 | 51 | categors = data_catg[target_current] 52 | target_current_true = targeter[categors[category]] 53 | self.samples.append((path_current, target_current_true)) 54 | 55 | # __getitem__ and __len__ inherited from ImageFolder 56 | 57 | 58 | def build_dataset(is_train, args): 59 | transform = build_transform(is_train, args) 60 | 61 | if args.data_set == 'CIFAR': 62 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 63 | nb_classes = 100 64 | elif args.data_set == 'IMNET': 65 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 66 | dataset = datasets.ImageFolder(root, transform=transform) 67 | nb_classes = 1000 68 | elif args.data_set == 'INAT': 69 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 70 | category=args.inat_category, transform=transform) 71 | nb_classes = dataset.nb_classes 72 | elif args.data_set == 'INAT19': 73 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 74 | category=args.inat_category, transform=transform) 75 | nb_classes = dataset.nb_classes 76 | elif args.data_set == 'IMNET100': 77 | root = os.path.join(args.data_path, 'train100' if is_train else 'val100') 78 | dataset = datasets.ImageFolder(root, transform=transform) 79 | nb_classes = 100 80 | 81 | return dataset, nb_classes 82 | 83 | 84 | def build_transform(is_train, args): 85 | resize_im = args.input_size > 32 86 | if is_train: 87 | # this should always dispatch to transforms_imagenet_train 88 | transform = create_transform( 89 | input_size=args.input_size, 90 | is_training=True, 91 | color_jitter=args.color_jitter, 92 | 
auto_augment=args.aa, 93 | interpolation=args.train_interpolation, 94 | re_prob=args.reprob, 95 | re_mode=args.remode, 96 | re_count=args.recount, 97 | ) 98 | if not resize_im: 99 | # replace RandomResizedCropAndInterpolation with 100 | # RandomCrop 101 | transform.transforms[0] = transforms.RandomCrop( 102 | args.input_size, padding=4) 103 | return transform 104 | 105 | t = [] 106 | if resize_im: 107 | size = int((256 / 224) * args.input_size) 108 | t.append( 109 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 110 | ) 111 | t.append(transforms.CenterCrop(args.input_size)) 112 | 113 | t.append(transforms.ToTensor()) 114 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 115 | return transforms.Compose(t) 116 | -------------------------------------------------------------------------------- /SPViT_DeiT/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
# Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He

"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import DistillationLoss
import utils


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    """Run one training epoch under AMP and return the averaged meter stats."""
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():

            outputs = model(samples)
            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order

        # loss.backward()
        # for name, param in model.named_parameters():
        #     if 'thresholds' in name:
        #         print(name, param.grad)

        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def train_one_epoch_pruning(model: torch.nn.Module, criterion: DistillationLoss,
                            data_loader: Iterable, optimizer1: torch.optim.Optimizer, optimizer2: torch.optim.Optimizer,
                            device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                            mixup_fn: Optional[Mixup] = None,
                            set_training_mode=True, logger=None):
    """Run one search/fine-tune epoch of the pruned model.

    Uses two optimizers: optimizer1 for the weights and optimizer2 for the
    architecture parameters (the latter is skipped once indicators are
    assigned, i.e. during fine-tuning). Returns the averaged meter stats and
    the last logged MSA-thresholds string.
    """
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # Fix: bind before the loop so the return below cannot raise NameError
    # when the data loader yields no batches.
    str_msa_thresholds = ''

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():

            outputs, msa_indicators_list, msa_thresholds_list, ffn_indicators_list = model(samples)
            loss_cls = criterion(samples, outputs, targets)

            if not model.module.assigned_indicators:
                # Searching: add the BOPs (budget) penalty to the task loss.
                loss_bop = model.module.calculate_bops_loss()
                loss = loss_cls + loss_bop
            else:
                loss_bop = torch.zeros(1).to(loss_cls.device)
                loss = loss_cls

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer1.zero_grad()
        optimizer2.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order1 = hasattr(optimizer1, 'is_second_order') and optimizer1.is_second_order

        if not model.module.assigned_indicators:
            loss_scaler(loss, optimizer1, optimizer2, clip_grad=max_norm, create_graph=is_second_order1, model=model)
        else:

            # Not using architecture optimizer during fine-tuning
            loss_scaler(loss, optimizer1, None, clip_grad=max_norm, create_graph=is_second_order1, model=model)

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_cls=loss_cls.item())
        metric_logger.update(loss_bop=loss_bop.item())
        metric_logger.update(lr=optimizer1.param_groups[0]["lr"])

        str_msa_thresholds = ''
        if not model.module.assigned_indicators and utils.get_rank() == 0:
            str_msa_thresholds = str(
                [["{:.3f}".format(i.item()) for i in blocks] for blocks in msa_thresholds_list])

            logger.info(str_msa_thresholds)

            str_ffn_indicators = str(
                [i.item() for i in ffn_indicators_list])

            logger.info(str_ffn_indicators)

        # break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, str_msa_thresholds


@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate the (unpruned) model; returns averaged loss/acc1/acc5."""
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        # metric_logger.log_indicator(indicators)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate_pruning(data_loader, model, device):
    """Evaluate the pruned model and, when searching, report indicators/FLOPs.

    Returns the averaged meter stats plus stringified MSA indicators,
    MSA thresholds, FFN indicators, FLOPs, and the raw threshold values.
    The indicator strings are built from the *last* batch's outputs
    (indicators are architecture-level, not sample-level).
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output, msa_indicators_list, msa_thresholds_list, ffn_indicators_list = model(images)
            loss = criterion(output, target)

        # metric_logger.log_indicator(indicators)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

        # break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    str_msa_indicators = ''
    str_msa_thresholds = ''
    str_ffn_indicators = ''
    str_flops = ''
    msa_thresholds = []

    # If searching, print some stuff
    if not model.module.assigned_indicators and utils.get_rank() == 0:
        # NOTE(review): *_list variables come from the final batch; if the
        # loader is empty this branch would raise NameError — confirm callers
        # never pass an empty loader.
        str_msa_indicators = str(
            [[i.item() for i in blocks] for blocks in msa_indicators_list])

        print('str_msa_indicators: ', str_msa_indicators)

        str_msa_thresholds = str(
            [["{:.3f}".format(i.item()) for i in blocks] for blocks in msa_thresholds_list])

        print('str_msa_thresholds: ', str_msa_thresholds)

        str_ffn_indicators = str(
            [i.item() for i in ffn_indicators_list])

        print('str_ffn_indicators: ', str_ffn_indicators)

        str_flops = str("{:.3f}".format(model.module.flops()[0].item() / 1e9))
        print('flops: ', str_flops)

        msa_thresholds = [[i.item() for i in blocks] for blocks in msa_thresholds_list]

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, str_msa_indicators, str_msa_thresholds,\
        str_ffn_indicators, str_flops, msa_thresholds
# --------------------------------------------------------------------------------
# /SPViT_DeiT/ffn_indicators/.DS_Store:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/.DS_Store
# --------------------------------------------------------------------------------
# /SPViT_DeiT/ffn_indicators/spvit_deit_bs_l006_t100_search_15epoch.pth:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_bs_l006_t100_search_15epoch.pth
# --------------------------------------------------------------------------------
# /SPViT_DeiT/ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth -------------------------------------------------------------------------------- /SPViT_DeiT/ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth -------------------------------------------------------------------------------- /SPViT_DeiT/ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth -------------------------------------------------------------------------------- /SPViT_DeiT/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
# Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He

from models import *

dependencies = ["torch", "torchvision", "timm"]
# --------------------------------------------------------------------------------
# /SPViT_DeiT/logger.py:
# --------------------------------------------------------------------------------
# Module-level logging setup. Importing this module has side effects:
# it creates outputs/<exp_name>/, mutates `args` (auto_resume / resume /
# output_dir), dumps args.json, and attaches a file handler to the logger.
import logging
import os
from pathlib import Path
import json
from datetime import datetime
from params import args

dt = datetime.now()
# Fix: datetime.replace() returns a new instance instead of mutating, so the
# original call was a silent no-op; keep the localized timestamp.
# NOTE(review): `dt` is not used below — confirm whether it is read elsewhere.
dt = dt.replace(tzinfo=datetime.now().astimezone().tzinfo)
_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_DATE_FMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
logger = logging.getLogger('__main__')  # this is the global logger
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
output_dir = os.path.join('outputs', args.exp_name)
Path(output_dir).mkdir(parents=True, exist_ok=True)
checkpoint_path = os.path.join(output_dir, 'last_checkpoint.pth')


# Here we auto load checkpoint even if there is a resume file
setattr(args, 'auto_resume', False)
if os.path.exists(checkpoint_path):
    setattr(args, 'resume', checkpoint_path)
    setattr(args, 'auto_resume', True)

setattr(args, 'output_dir', output_dir)

if not args.eval:
    log_path = os.path.join(output_dir, 'all_logs.txt')
    # Persist the resolved arguments next to the logs for reproducibility.
    with open(os.path.join(args.output_dir, 'args.json'), 'w+') as f:
        json.dump(vars(args), f, indent=4)
else:
    log_path = os.path.join(output_dir, 'eval_logs.txt')

# Mirror all log records into a per-experiment file (append mode).
fh = logging.FileHandler(log_path, 'a+')
formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
fh.setFormatter(formatter)
logger.addHandler(fh)
# --------------------------------------------------------------------------------
# /SPViT_DeiT/losses.py:
# --------------------------------------------------------------------------------
#
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He

"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """Wrap a base criterion with an optional knowledge-distillation term.

    A frozen teacher model's predictions on the same inputs serve as extra
    supervision; `alpha` blends base and distillation losses and `tau` is the
    temperature used for soft distillation.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        assert distillation_type in ['none', 'soft', 'hard']
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are feed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_dist = outputs
        base_loss = self.base_criterion(outputs, labels)

        if self.distillation_type == 'none':
            return base_loss

        # don't backprop throught the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            kd = F.kl_div(
                F.log_softmax(outputs_dist / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True,
            )
            distillation_loss = kd * (T * T) / outputs_dist.numel()
        else:  # 'hard'
            distillation_loss = F.cross_entropy(outputs_dist, teacher_outputs.argmax(dim=1))

        return base_loss * (1 - self.alpha) + distillation_loss * self.alpha
# --------------------------------------------------------------------------------
# /SPViT_DeiT/main.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
3 | import argparse 4 | import random 5 | import datetime 6 | import numpy as np 7 | import time 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import json 11 | import os 12 | from pathlib import Path 13 | 14 | from timm.data import Mixup 15 | from timm.models import create_model 16 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 17 | from timm.scheduler import create_scheduler 18 | from timm.optim import create_optimizer 19 | from timm.utils import NativeScaler, get_state_dict, ModelEma 20 | 21 | from datasets import build_dataset 22 | from engine import train_one_epoch, evaluate 23 | from losses import DistillationLoss 24 | from samplers import RASampler 25 | from models import Attention, get_attention_flops 26 | import utils 27 | from params import args 28 | from logger import logger 29 | 30 | from timm.models import model_entrypoint 31 | 32 | 33 | class Custom_scaler: 34 | state_dict_key = "amp_scaler" 35 | 36 | def __init__(self): 37 | self._scaler = torch.cuda.amp.GradScaler() 38 | 39 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False): 40 | self._scaler.scale(loss).backward(create_graph=create_graph) 41 | 42 | if clip_grad is not None: 43 | assert parameters is not None 44 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 45 | torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 46 | self._scaler.step(optimizer) 47 | self._scaler.update() 48 | 49 | def state_dict(self): 50 | return self._scaler.state_dict() 51 | 52 | def load_state_dict(self, state_dict): 53 | self._scaler.load_state_dict(state_dict) 54 | 55 | 56 | def main(): 57 | utils.init_distributed_mode(args) 58 | if utils.get_rank() != 0: 59 | logger.disabled = True 60 | print(args) 61 | 62 | if args.distillation_type != 'none' and args.finetune and not args.eval: 63 | raise NotImplementedError("Finetuning with distillation not yet supported") 64 | 65 | device = 
torch.device(args.device) 66 | 67 | # fix the seed for reproducibility 68 | torch.backends.cudnn.deterministic = True 69 | seed = args.seed + utils.get_rank() 70 | torch.manual_seed(seed) 71 | torch.cuda.manual_seed(seed) 72 | torch.cuda.manual_seed_all(seed) 73 | np.random.seed(seed) 74 | random.seed(seed) 75 | 76 | cudnn.benchmark = True 77 | 78 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 79 | dataset_val, _ = build_dataset(is_train=False, args=args) 80 | 81 | if True: # args.distributed: 82 | num_tasks = utils.get_world_size() 83 | global_rank = utils.get_rank() 84 | if args.repeated_aug: 85 | sampler_train = RASampler( 86 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 87 | ) 88 | else: 89 | sampler_train = torch.utils.data.DistributedSampler( 90 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 91 | ) 92 | if args.dist_eval: 93 | if len(dataset_val) % num_tasks != 0: 94 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 95 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 96 | 'equal num of samples per-process.') 97 | sampler_val = torch.utils.data.DistributedSampler( 98 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 99 | else: 100 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 101 | else: 102 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 103 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 104 | 105 | data_loader_train = torch.utils.data.DataLoader( 106 | dataset_train, sampler=sampler_train, 107 | batch_size=args.batch_size, 108 | num_workers=args.num_workers, 109 | pin_memory=args.pin_mem, 110 | drop_last=True, 111 | ) 112 | 113 | data_loader_val = torch.utils.data.DataLoader( 114 | dataset_val, sampler=sampler_val, 115 | batch_size=int(1.5 * args.batch_size), 116 | num_workers=args.num_workers, 117 | pin_memory=args.pin_mem, 118 | drop_last=False 119 | ) 120 | 121 | mixup_fn = None 122 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 123 | if mixup_active: 124 | mixup_fn = Mixup( 125 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 126 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 127 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 128 | 129 | logger.info(f"Creating model: {args.model}") 130 | model = create_model( 131 | args.model, 132 | pretrained=False, 133 | num_classes=args.nb_classes, 134 | drop_rate=args.drop, 135 | drop_path_rate=args.drop_path, 136 | drop_block_rate=None, 137 | # att_mode=args.att_mode 138 | ) 139 | 140 | # if utils.get_rank() == 0: 141 | # # print_size_of_model(model) 142 | # try: 143 | # from ptflops import get_model_complexity_info 144 | # macs, params = get_model_complexity_info(model, (3, args.input_size, args.input_size), as_strings=True, 145 | # print_per_layer_stat=False, verbose=False, custom_modules_hooks={Attention:get_attention_flops}) 146 | # # flops = macs 147 | # logger.info('{:<30} {:<8}'.format('MACs: ', macs)) 148 | # logger.info('{:<30} {:<8}'.format('Number of parameters: ', params)) 149 | # except: 150 | # pass 151 | 152 | if args.finetune: 153 | if args.finetune.startswith('https'): 154 | checkpoint = torch.hub.load_state_dict_from_url( 155 | args.finetune, map_location='cpu', check_hash=True) 156 | else: 157 | checkpoint = torch.load(args.finetune, map_location='cpu') 158 | 159 | checkpoint_model = checkpoint['model'] 160 | state_dict = model.state_dict() 161 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 162 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 163 | print(f"Removing key {k} from pretrained checkpoint") 164 | del checkpoint_model[k] 165 | 166 | # interpolate position embedding 167 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 168 | embedding_size = pos_embed_checkpoint.shape[-1] 169 | num_patches = model.patch_embed.num_patches 170 | 
num_extra_tokens = model.pos_embed.shape[-2] - num_patches 171 | # height (== width) for the checkpoint position embedding 172 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 173 | # height (== width) for the new position embedding 174 | new_size = int(num_patches ** 0.5) 175 | # class_token and dist_token are kept unchanged 176 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 177 | # only the position tokens are interpolated 178 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 179 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 180 | pos_tokens = torch.nn.functional.interpolate( 181 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 182 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 183 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 184 | checkpoint_model['pos_embed'] = new_pos_embed 185 | 186 | model.load_state_dict(checkpoint_model, strict=False) 187 | 188 | model.to(device) 189 | 190 | model_without_ddp = model 191 | if args.distributed: 192 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 193 | model_without_ddp = model.module 194 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 195 | logger.info('number of params: ' + str(n_parameters)) 196 | 197 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 198 | args.lr = linear_scaled_lr 199 | optimizer = create_optimizer(args, model_without_ddp) 200 | loss_scaler = NativeScaler() 201 | 202 | lr_scheduler, _ = create_scheduler(args, optimizer) 203 | 204 | criterion = LabelSmoothingCrossEntropy() 205 | 206 | if args.mixup > 0.: 207 | # smoothing is handled with mixup label transform 208 | criterion = SoftTargetCrossEntropy() 209 | elif args.smoothing: 210 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 211 | else: 212 | criterion 
= torch.nn.CrossEntropyLoss() 213 | 214 | teacher_model = None 215 | if args.distillation_type != 'none': 216 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 217 | print(f"Creating teacher model: {args.teacher_model}") 218 | teacher_model = create_model( 219 | args.teacher_model, 220 | pretrained=False, 221 | num_classes=args.nb_classes, 222 | global_pool='avg', 223 | ) 224 | if args.teacher_path.startswith('https'): 225 | checkpoint = torch.hub.load_state_dict_from_url( 226 | args.teacher_path, map_location='cpu', check_hash=True) 227 | else: 228 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 229 | teacher_model.load_state_dict(checkpoint['model']) 230 | teacher_model.to(device) 231 | teacher_model.eval() 232 | 233 | # wrap the criterion in our custom DistillationLoss, which 234 | # just dispatches to the original criterion if args.distillation_type is 'none' 235 | criterion = DistillationLoss( 236 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 237 | ) 238 | 239 | output_dir = Path(args.output_dir) 240 | if args.resume: 241 | if args.resume.startswith('https'): 242 | checkpoint = torch.hub.load_state_dict_from_url( 243 | args.resume, map_location='cpu', check_hash=True) 244 | else: 245 | checkpoint = torch.load(args.resume, map_location='cpu') 246 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 247 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 248 | optimizer.load_state_dict(checkpoint['optimizer']) 249 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 250 | args.start_epoch = checkpoint['epoch'] + 1 251 | # if args.model_ema: 252 | # utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 253 | if 'scaler' in checkpoint: 254 | loss_scaler.load_state_dict(checkpoint['scaler']) 255 | if args.eval: 256 | test_stats = evaluate(data_loader_val, model, device) 
257 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 258 | return 259 | if args.throughput: 260 | throughput(data_loader_val, model, logger) 261 | return 262 | 263 | logger.info(f"Start training for {args.epochs} epochs") 264 | start_time = time.time() 265 | max_accuracy = 0.0 266 | for epoch in range(args.start_epoch, args.epochs): 267 | if args.distributed: 268 | data_loader_train.sampler.set_epoch(epoch) 269 | 270 | train_stats = train_one_epoch( 271 | model, criterion, data_loader_train, 272 | optimizer, device, epoch, loss_scaler, 273 | args.clip_grad, mixup_fn, 274 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning 275 | ) 276 | 277 | lr_scheduler.step(epoch) 278 | if args.output_dir: 279 | checkpoint_paths = [output_dir / 'last_checkpoint.pth'] 280 | for checkpoint_path in checkpoint_paths: 281 | utils.save_on_master({ 282 | 'model': model_without_ddp.state_dict(), 283 | 'optimizer': optimizer.state_dict(), 284 | 'lr_scheduler': lr_scheduler.state_dict(), 285 | 'epoch': epoch, 286 | 'scaler': loss_scaler.state_dict(), 287 | 'args': args, 288 | }, checkpoint_path) 289 | 290 | test_stats = evaluate(data_loader_val, model, device) 291 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 292 | if max_accuracy < test_stats["acc1"]: 293 | utils.save_on_master({ 294 | 'model': model_without_ddp.state_dict(), 295 | 'optimizer': optimizer.state_dict(), 296 | 'lr_scheduler': lr_scheduler.state_dict(), 297 | 'epoch': epoch, 298 | 'scaler': loss_scaler.state_dict(), 299 | 'args': args, 300 | }, os.path.join(args.output_dir, 'best_checkpoint.pth')) 301 | 302 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 303 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 304 | 305 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 306 | **{f'test_{k}': v for k, v in test_stats.items()}, 307 | 'epoch': epoch, 308 | 
'n_parameters': n_parameters} 309 | 310 | if args.output_dir and utils.is_main_process(): 311 | with (output_dir / "log.txt").open("a") as f: 312 | f.write(json.dumps(log_stats) + "\n") 313 | 314 | total_time = time.time() - start_time 315 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 316 | logger.info('Training time {}'.format(total_time_str)) 317 | 318 | 319 | @torch.no_grad() 320 | def throughput(data_loader, model, logger): 321 | model.eval() 322 | 323 | for idx, (images, _) in enumerate(data_loader): 324 | images = images.cuda(non_blocking=True) 325 | batch_size = images.shape[0] 326 | for i in range(50): 327 | model(images) 328 | torch.cuda.synchronize() 329 | logger.info(f"throughput averaged with 100 times") 330 | tic1 = time.time() 331 | for i in range(100): 332 | model(images) 333 | torch.cuda.synchronize() 334 | tic2 = time.time() 335 | logger.info(f"batch_size {batch_size} throughput {100 * batch_size / (tic2 - tic1)}") 336 | return 337 | 338 | 339 | if __name__ == '__main__': 340 | main() 341 | -------------------------------------------------------------------------------- /SPViT_DeiT/params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import json 4 | 5 | 6 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 7 | parser.add_argument('--batch-size', default=64, type=int) 8 | parser.add_argument('--epochs', default=300, type=int) 9 | 10 | # Model parameters 11 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', 12 | help='Name of model to train') 13 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 14 | 15 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 16 | help='Dropout rate (default: 0.)') 17 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 18 | help='Drop path rate (default: 0.1)') 19 | 
# SPViT_DeiT/params.py (continued): remaining CLI options plus the JSON-config
# merge. Every flag maps onto an `args` attribute consumed by main.py.

parser.add_argument('--model-ema', action='store_true')
parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
parser.set_defaults(model_ema=True)
parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
                    help='weight decay (default: 0.05)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                    help='learning rate (default: 5e-4)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                    help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')

# Augmentation parameters
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
# BUG FIX: the help string previously contained a garbled '" + \' residue of a
# broken string concatenation, and the call had a stray trailing comma that
# turned the statement into a one-element tuple expression.
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='bicubic',
                    help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

parser.add_argument('--repeated-aug', action='store_true')
parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
parser.set_defaults(repeated_aug=True)

# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')

# * Mixup params
parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

# Distillation parameters
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                    help='Name of teacher model to train (default: "regnety_160")')
parser.add_argument('--teacher-path', type=str, default='')
parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

# * Finetuning params
parser.add_argument('--finetune', default='', help='finetune from checkpoint')

# Dataset parameters
parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                    help='dataset path')
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                    type=str, help='Image Net dataset path')
parser.add_argument('--inat-category', default='name',
                    choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                    type=str, help='semantic granularity')

parser.add_argument('--output_dir', default='',
                    help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
                    help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                    help='start epoch')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Perform throughput only')
parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin-mem', action='store_true',
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                    help='')
parser.set_defaults(pin_mem=True)

# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--exp_name', default='deit',
                    type=str, help='model configuration')
parser.add_argument('--config', default=None,
                    type=str, help='model configuration')
parser.add_argument('--patch_size', default=16, type=int)
parser.add_argument('--num_heads', default=3, type=int)
parser.add_argument('--head_dim', default=64, type=int)
parser.add_argument('--num_blocks', default=12, type=int)
# NOTE(review): '--input_size' shares dest 'input_size' with '--input-size'
# above (same default) -- confirm both spellings are intentionally accepted.
parser.add_argument('--input_size', default=224, type=int, help='images input size')
parser.add_argument('--sparse_block_mode', default=0, type=int, help='sparse policy')
parser.add_argument('--custom_blocks', default='', type=str, help='custom sparse blocks')
parser.add_argument('--transformer_type', default='normal', type=str, help='')
parser.add_argument('--local_rank', default=0, type=int, help='')

args = parser.parse_args()
if args.config is not None:
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config) as config_file:
        config_args = json.load(config_file)
    # Options explicitly given on the command line win over the JSON config.
    # BUG FIX: argparse stores '--foo-bar' as attribute 'foo_bar', so dashes
    # must be normalized before comparing with the underscore-style config
    # keys; previously a CLI override such as '--batch-size 128' was silently
    # clobbered by the config value because 'batch-size' != 'batch_size'.
    override_keys = {arg[2:].split('=')[0].replace('-', '_') for arg in sys.argv[1:]
                     if arg.startswith('--')}
    for k, v in config_args.items():
        if k not in override_keys:
            setattr(args, k, v)
    del args.config
# SPViT_DeiT/post_training_optimize_checkpoint.py
import sys
import torch
from collections import OrderedDict
import ast


def main():
    """Fold the searched MSA head probabilities of an SPViT checkpoint into
    explicit convolution ('bconv') weights, writing a parameter-optimized copy.

    Usage:
        python post_training_optimize_checkpoint.py <checkpoint_path> <MSA_indicators>

    where <MSA_indicators> is a Python-literal list with one entry per block.
    The result is saved next to the input as '<name>_optimized.pth'.
    """
    if len(sys.argv) != 3:
        print('Error: Two input arguments, checkpoint_path and searched MSA architecture in a list.')
        return

    checkpoint_path = sys.argv[1]
    state_dict = torch.load(checkpoint_path, map_location='cpu')['model']

    try:
        # ast.literal_eval safely parses the command-line string as a list
        # (no arbitrary code execution, unlike eval).
        MSA_indicators = ast.literal_eval(sys.argv[2])

        if not isinstance(MSA_indicators, list):
            raise ValueError("The provided parameter is not a valid list.")

        print(f"List Parameter: {MSA_indicators}")

    except (ValueError, SyntaxError) as e:
        print(f"Invalid MSA indicators: {e}")
        # BUG FIX: execution previously fell through after this message and
        # crashed below with a NameError on the unbound MSA_indicators.
        return

    new_dict = OrderedDict()

    if any('bconv' in key for key in state_dict.keys()):
        print('Error: The checkpoint is already optimized!')
        return

    for k in state_dict:
        if 'head_probs' in k:
            # Key looks like 'blocks.<i>.attn.head_probs'; strip the suffix to
            # address sibling tensors of the same attention module.
            block_name = k.replace('head_probs', '')
            block_num = int(k.split('.')[1])
            # Temperature-sharpened (T = 1e-2) softmax over the first dim.
            head_probs = (state_dict[k] / 1e-2).softmax(0)
            num_heads = head_probs.shape[0]
            feature_dim = state_dict[block_name + 'v.weight'].shape[0]
            head_dim = feature_dim // num_heads

            if MSA_indicators[block_num][-1] == 1:
                print('Error: checkpoint and MSA indicators do not match!')
                return

            # Blend per-head value / projection weights with the learned
            # head probabilities (matmul over the heads dimension).
            new_v_weight = state_dict[block_name + 'v.weight'].view(num_heads, head_dim, feature_dim).permute(1, 2, 0) @ head_probs
            new_v_bias = state_dict[block_name + 'v.bias'].view(num_heads, head_dim).permute(1, 0) @ head_probs
            new_proj_weight = state_dict[block_name + 'proj.weight'].view(feature_dim, num_heads, head_dim).permute(0, 2, 1) @ head_probs

            if MSA_indicators[block_num][1] == 1:
                # This block's MSA collapses to a 3x3 convolution.
                bn_name = 'bn_3x3.'
                new_dict[block_name + 'bconv.0.weight'] = new_v_weight.permute(2, 0, 1).view(3, 3, head_dim, -1).permute(2, 3, 1, 0)
                new_dict[block_name + 'bconv.0.bias'] = new_v_bias.sum(-1)
                new_dict[block_name + 'bconv.3.weight'] = new_proj_weight.sum(-1)[..., None, None]
            else:
                # This block's MSA collapses to a 1x1 convolution. Index 4
                # appears to pick the centre tap of the 3x3 position grid
                # (see the view(3, 3, ...) above) -- TODO confirm against the
                # model definition in models_pruning.py.
                bn_name = 'bn_1x1.'
                new_dict[block_name + 'bconv.0.weight'] = new_v_weight[..., 4][..., None, None]
                new_dict[block_name + 'bconv.0.bias'] = new_v_bias[..., 4]
                new_dict[block_name + 'bconv.3.weight'] = new_proj_weight[..., 4][..., None, None]

            new_dict[block_name + 'bconv.3.bias'] = state_dict[block_name + 'proj.bias']

            # Carry over the matching BatchNorm parameters and statistics.
            new_dict[block_name + 'bconv.1.weight'] = state_dict[block_name + bn_name + 'weight']
            new_dict[block_name + 'bconv.1.bias'] = state_dict[block_name + bn_name + 'bias']
            new_dict[block_name + 'bconv.1.running_mean'] = state_dict[block_name + bn_name + 'running_mean']
            new_dict[block_name + 'bconv.1.running_var'] = state_dict[block_name + bn_name + 'running_var']
            new_dict[block_name + 'bconv.1.num_batches_tracked'] = state_dict[block_name + bn_name + 'num_batches_tracked']

        else:
            # Keep every tensor that does not belong to a pruned MSA module
            # (i.e. whose grandparent module has no 'head_probs').
            if len(k.split('.')) <= 4 or '.'.join(k.split('.')[:-2]) + '.head_probs' not in state_dict.keys():
                new_dict[k] = state_dict[k]

    torch.save({'model': new_dict}, '.'.join(checkpoint_path.split('.')[:-1]) + '_optimized.pth')


if __name__ == '__main__':
    main()
# SPViT_DeiT/samplers.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He

import torch
import torch.distributed as dist
import math


class RASampler(torch.utils.data.Sampler):
    """Distributed sampler with repeated augmentation (x3).

    Every sample appears three times per epoch, and the copies are dealt
    round-robin across processes so each GPU sees a different augmented
    version of the same image.
    Heavily based on torch.utils.data.DistributedSampler.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Each sample is tripled before sharding across replicas.
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Count actually yielded per process: dataset length rounded down to a
        # multiple of 256, split evenly across replicas.
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministic shuffle keyed on the epoch so every process derives
        # the same global ordering.
        gen = torch.Generator()
        gen.manual_seed(self.epoch)
        if self.shuffle:
            order = torch.randperm(len(self.dataset), generator=gen).tolist()
        else:
            order = list(range(len(self.dataset)))

        # Triple every index (repeated augmentation) ...
        tripled = []
        for sample_idx in order:
            tripled.extend((sample_idx,) * 3)
        # ... and pad with leading entries so the list divides evenly.
        tripled += tripled[:self.total_size - len(tripled)]
        assert len(tripled) == self.total_size

        # Round-robin shard: this rank takes every num_replicas-th entry,
        # so the three copies of one sample land on different ranks.
        shard = tripled[self.rank:self.total_size:self.num_replicas]
        assert len(shard) == self.num_samples

        return iter(shard[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
37 | """ 38 | 39 | def __init__(self, window_size=20, fmt=None): 40 | if fmt is None: 41 | fmt = "{median:.4f} ({global_avg:.4f})" 42 | self.deque = deque(maxlen=window_size) 43 | self.total = 0.0 44 | self.count = 0 45 | self.fmt = fmt 46 | 47 | def update(self, value, n=1): 48 | self.deque.append(value) 49 | self.count += n 50 | self.total += value * n 51 | 52 | def synchronize_between_processes(self): 53 | """ 54 | Warning: does not synchronize the deque! 55 | """ 56 | if not is_dist_avail_and_initialized(): 57 | return 58 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 59 | dist.barrier() 60 | dist.all_reduce(t) 61 | t = t.tolist() 62 | self.count = int(t[0]) 63 | self.total = t[1] 64 | 65 | @property 66 | def median(self): 67 | d = torch.tensor(list(self.deque)) 68 | return d.median().item() 69 | 70 | @property 71 | def avg(self): 72 | d = torch.tensor(list(self.deque), dtype=torch.float32) 73 | return d.mean().item() 74 | 75 | @property 76 | def global_avg(self): 77 | return self.total / self.count 78 | 79 | @property 80 | def max(self): 81 | return max(self.deque) 82 | 83 | @property 84 | def value(self): 85 | return self.deque[-1] 86 | 87 | def __str__(self): 88 | return self.fmt.format( 89 | median=self.median, 90 | avg=self.avg, 91 | global_avg=self.global_avg, 92 | max=self.max, 93 | value=self.value) 94 | 95 | 96 | class MetricLogger(object): 97 | def __init__(self, delimiter="\t"): 98 | self.meters = defaultdict(SmoothedValue) 99 | self.delimiter = delimiter 100 | 101 | def update(self, **kwargs): 102 | for k, v in kwargs.items(): 103 | if isinstance(v, torch.Tensor): 104 | v = v.item() 105 | assert isinstance(v, (float, int)) 106 | self.meters[k].update(v) 107 | 108 | def __getattr__(self, attr): 109 | if attr in self.meters: 110 | return self.meters[attr] 111 | if attr in self.__dict__: 112 | return self.__dict__[attr] 113 | raise AttributeError("'{}' object has no attribute '{}'".format( 114 | type(self).__name__, 
attr)) 115 | 116 | def __str__(self): 117 | loss_str = [] 118 | for name, meter in self.meters.items(): 119 | loss_str.append( 120 | "{}: {}".format(name, str(meter)) 121 | ) 122 | return self.delimiter.join(loss_str) 123 | 124 | def synchronize_between_processes(self): 125 | for meter in self.meters.values(): 126 | meter.synchronize_between_processes() 127 | 128 | def add_meter(self, name, meter): 129 | self.meters[name] = meter 130 | 131 | def log_every(self, iterable, print_freq, header=None): 132 | i = 0 133 | if not header: 134 | header = '' 135 | start_time = time.time() 136 | end = time.time() 137 | iter_time = SmoothedValue(fmt='{avg:.4f}') 138 | data_time = SmoothedValue(fmt='{avg:.4f}') 139 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 140 | log_msg = [ 141 | header, 142 | '[{0' + space_fmt + '}/{1}]', 143 | 'eta: {eta}', 144 | '{meters}', 145 | 'time: {time}', 146 | 'data: {data}' 147 | ] 148 | if torch.cuda.is_available(): 149 | log_msg.append('max mem: {memory:.0f}') 150 | log_msg = self.delimiter.join(log_msg) 151 | MB = 1024.0 * 1024.0 152 | for obj in iterable: 153 | data_time.update(time.time() - end) 154 | yield obj 155 | iter_time.update(time.time() - end) 156 | if i % print_freq == 0 or i == len(iterable) - 1: 157 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 158 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 159 | if torch.cuda.is_available(): 160 | logger.info(log_msg.format( 161 | i, len(iterable), eta=eta_string, 162 | meters=str(self), 163 | time=str(iter_time), data=str(data_time), 164 | memory=torch.cuda.max_memory_allocated() / MB)) 165 | else: 166 | logger.info(log_msg.format( 167 | i, len(iterable), eta=eta_string, 168 | meters=str(self), 169 | time=str(iter_time), data=str(data_time))) 170 | i += 1 171 | end = time.time() 172 | total_time = time.time() - start_time 173 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 174 | logger.info('{} Total time: {} ({:.4f} s / 
it)'.format( 175 | header, total_time_str, total_time / len(iterable))) 176 | 177 | def log_indicator(self, info): 178 | logger.info('indicators: ' + str(info)) 179 | 180 | 181 | def _load_checkpoint_for_ema(model_ema, checkpoint): 182 | """ 183 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 184 | """ 185 | mem_file = io.BytesIO() 186 | torch.save(checkpoint, mem_file) 187 | mem_file.seek(0) 188 | model_ema._load_checkpoint(mem_file) 189 | 190 | 191 | def setup_for_distributed(is_master): 192 | """ 193 | This function disables printing when not in master process 194 | """ 195 | import builtins as __builtin__ 196 | builtin_print = __builtin__.print 197 | 198 | def print(*args, **kwargs): 199 | force = kwargs.pop('force', False) 200 | if is_master or force: 201 | builtin_print(*args, **kwargs) 202 | 203 | __builtin__.print = print 204 | 205 | 206 | def is_dist_avail_and_initialized(): 207 | if not dist.is_available(): 208 | return False 209 | if not dist.is_initialized(): 210 | return False 211 | return True 212 | 213 | 214 | def get_world_size(): 215 | if not is_dist_avail_and_initialized(): 216 | return 1 217 | return dist.get_world_size() 218 | 219 | 220 | def get_rank(): 221 | if not is_dist_avail_and_initialized(): 222 | return 0 223 | return dist.get_rank() 224 | 225 | 226 | def is_main_process(): 227 | return get_rank() == 0 228 | 229 | 230 | def save_on_master(*args, **kwargs): 231 | if is_main_process(): 232 | torch.save(*args, **kwargs) 233 | 234 | 235 | def init_distributed_mode(args): 236 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 237 | args.rank = int(os.environ["RANK"]) 238 | args.world_size = int(os.environ['WORLD_SIZE']) 239 | args.gpu = int(os.environ['LOCAL_RANK']) 240 | elif 'SLURM_PROCID' in os.environ: 241 | args.rank = int(os.environ['SLURM_PROCID']) 242 | args.gpu = args.rank % torch.cuda.device_count() 243 | else: 244 | logger.info('Not using distributed mode') 245 | args.distributed = False 246 | 
return 247 | 248 | args.distributed = True 249 | 250 | torch.cuda.set_device(args.gpu) 251 | args.dist_backend = 'nccl' 252 | logger.info('| distributed init (rank {}): {}'.format( 253 | args.rank, args.dist_url)) 254 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 255 | world_size=args.world_size, rank=args.rank) 256 | torch.distributed.barrier() 257 | setup_for_distributed(args.rank == 0) 258 | 259 | 260 | def create_scheduler(args, optimizer, epochs, warmup_epochs, min_lr): 261 | num_epochs = epochs 262 | 263 | if getattr(args, 'lr_noise', None) is not None: 264 | lr_noise = getattr(args, 'lr_noise') 265 | if isinstance(lr_noise, (list, tuple)): 266 | noise_range = [n * num_epochs for n in lr_noise] 267 | if len(noise_range) == 1: 268 | noise_range = noise_range[0] 269 | else: 270 | noise_range = lr_noise * num_epochs 271 | else: 272 | noise_range = None 273 | 274 | lr_scheduler = None 275 | if args.sched == 'cosine': 276 | lr_scheduler = CosineLRScheduler( 277 | optimizer, 278 | t_initial=num_epochs, 279 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 280 | lr_min=args.min_lr, 281 | decay_rate=args.decay_rate, 282 | warmup_lr_init=args.warmup_lr, 283 | warmup_t=warmup_epochs, 284 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 285 | t_in_epochs=True, 286 | noise_range_t=noise_range, 287 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 288 | noise_std=getattr(args, 'lr_noise_std', 1.), 289 | noise_seed=getattr(args, 'seed', 42), 290 | ) 291 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs 292 | elif args.sched == 'tanh': 293 | lr_scheduler = TanhLRScheduler( 294 | optimizer, 295 | t_initial=num_epochs, 296 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 297 | lr_min=args.min_lr, 298 | warmup_lr_init=args.warmup_lr, 299 | warmup_t=args.warmup_epochs, 300 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 301 | t_in_epochs=True, 302 | noise_range_t=noise_range, 303 | noise_pct=getattr(args, 'lr_noise_pct', 
def add_weight_decay_2ops(model, weight_decay=1e-5, skip_list=()):
    """Split trainable parameters into optimizer param groups.

    Returns two structures:
      * weight groups: the usual (no-decay, decay) pair — 1-D tensors,
        biases and ``skip_list`` entries get zero weight decay;
      * architecture groups: a single zero-decay group holding every
        parameter whose name contains ``thresholds``.
    """
    decay = []
    no_decay = []
    diff_lr = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights take part in neither optimizer
        if 'thresholds' in name:
            diff_lr.append(param)
        elif len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}], [
        {'params': diff_lr, 'weight_decay': 0.}]


def create_2optimizers(args, model, filter_bias_and_bn=True):
    """Create two optimizers: one for the model weights and one for the
    architecture ``thresholds`` parameters (lr = ``args.arc_lr``).

    Bug fixes vs. the original:
      * the ``else`` branch stored ``model.parameters()`` in an unused
        name and left ``parameters1``/``parameters2`` undefined, which
        raised ``NameError`` as soon as the optimizers were built — the
        parameters are now always split via ``add_weight_decay_2ops``;
      * ``assert False and "Invalid optimizer"`` never carried a
        message (``False and x`` is just ``False``) and disappears under
        ``python -O``; replaced with an explicit ``ValueError``.

    Raises:
        ValueError: if ``args.opt`` is not a supported optimizer.
    """
    opt_lower = args.opt.lower()

    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        skip = {}
        if hasattr(model, 'no_weight_decay'):
            skip = model.no_weight_decay()
        parameters1, parameters2 = add_weight_decay_2ops(model, weight_decay, skip)
        weight_decay = 0.  # decay now lives inside the param groups
    else:
        # Still split, so the architecture optimizer gets its own group.
        parameters1, parameters2 = add_weight_decay_2ops(model, weight_decay)

    if 'fused' in opt_lower:
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args1 = dict(lr=args.lr, weight_decay=weight_decay)
    opt_args2 = dict(lr=args.arc_lr, weight_decay=0.)
    if getattr(args, 'opt_eps', None) is not None:
        opt_args1['eps'] = args.opt_eps
        opt_args2['eps'] = args.opt_eps
    if getattr(args, 'opt_betas', None) is not None:
        opt_args1['betas'] = args.opt_betas
        opt_args2['betas'] = args.opt_betas

    # e.g. "fused_adamw" -> "adamw"; the prefix only selects the family.
    opt_lower = opt_lower.split('_')[-1]
    if opt_lower == 'adamw':
        optimizer1 = optim.AdamW(parameters1, **opt_args1)
        optimizer2 = optim.AdamW(parameters2, **opt_args2)
    elif opt_lower == 'fusedadamw':
        optimizer1 = FusedAdam(parameters1, adam_w_mode=True, **opt_args1)
        optimizer2 = FusedAdam(parameters2, adam_w_mode=True, **opt_args2)
    else:
        raise ValueError(f'Invalid optimizer: {args.opt}')

    return optimizer1, optimizer2
class NativeScaler:
    """AMP gradient scaler adapted for the two-optimizer SPViT setup:
    both optimizers step against the same scaled gradients, then the
    scale factor is updated once."""
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer1, optimizer2, clip_grad=None, parameters=None, create_graph=False, model=None):
        # Scale + backprop, step the weight optimizer, then (if present)
        # the architecture optimizer, before a single scaler update.
        self._scaler.scale(loss).backward(create_graph=create_graph)
        self._scaler.step(optimizer1)
        if optimizer2:
            self._scaler.step(optimizer2)
        self._scaler.update()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def prune_ffn(checkpoint_dict, ffn_indicators, depth=12):
    """Prune FFN hidden units in a DeiT-style checkpoint in place.

    For every transformer block, keeps only the ``fc1`` rows (and bias
    entries) and ``fc2`` columns whose FFN indicator is non-zero.

    Fixes vs. the original: removed the dead ``i += 1`` at the end of
    the loop body (``for`` rebinds ``i`` each iteration, so it had no
    effect), and generalized the hard-coded depth of 12 into a
    parameter whose default preserves the original behavior.

    Args:
        checkpoint_dict: state dict holding ``blocks.{i}.mlp.*`` tensors.
        ffn_indicators: per-block 0/1 tensors over the hidden units.
        depth: number of transformer blocks to prune (default 12).

    Returns:
        The (mutated) ``checkpoint_dict``.
    """
    for i in range(depth):
        # Indices of the hidden units to keep in block i.
        keep = ffn_indicators[i].nonzero().squeeze(-1)

        in_dim = checkpoint_dict[f'blocks.{i}.mlp.fc1.weight'].shape[1]
        checkpoint_dict[f'blocks.{i}.mlp.fc1.weight'] = torch.gather(
            checkpoint_dict[f'blocks.{i}.mlp.fc1.weight'], 0,
            keep.unsqueeze(-1).expand(-1, in_dim))
        checkpoint_dict[f'blocks.{i}.mlp.fc1.bias'] = torch.gather(
            checkpoint_dict[f'blocks.{i}.mlp.fc1.bias'], 0, keep)
        checkpoint_dict[f'blocks.{i}.mlp.fc2.weight'] = torch.gather(
            checkpoint_dict[f'blocks.{i}.mlp.fc2.weight'], 1,
            keep.unsqueeze(0).expand(in_dim, -1))

    return checkpoint_dict


def save_ffn_indicators(model, epoch, logger, output_dir):
    """Save only the searched FFN indicator tensors to a checkpoint.

    Bug fix: the original built the filtered ``new_dict`` but then
    saved the full ``model.state_dict()``, which made the filtering
    dead code and the "indicator" checkpoint as large as a full model.
    The filtered dict is now what gets written.
    """
    new_dict = OrderedDict()

    old_dict = model.state_dict()
    for name in old_dict.keys():
        if 'assigned_indicator_index' in name:
            new_dict[name] = old_dict[name]

    save_path = os.path.join(output_dir, f'search_{epoch}epoch.pth')
    logger.info(f"{save_path} saving......")
    save_on_master({
        'model': new_dict
    }, save_path)
    logger.info(f"{save_path} saved !!!")
19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /output/ 131 | /detectron2/ 132 | detectron2 133 | detectron2/ 134 | /detectron2.OLD/ 135 | /detectron2.WRONG/ 136 | /detectron2.egg-info/ 137 | -------------------------------------------------------------------------------- /SPViT_Swin/README.md: -------------------------------------------------------------------------------- 1 | ### Getting started on SPViT-Swin: 2 | 3 | #### Installation and data preparation 4 | 5 | - First, you can install the required environments as illustrated in the [Swin](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md) repository or follow the instructions below: 6 | 7 | ```bash 8 | # Create virtual env 9 | conda create -n spvit-swin python=3.7 -y 10 | conda activate spvit-swin 11 | 12 | # Install PyTorch 13 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch 14 | pip install timm==0.3.2 15 | 16 | # Install Apex 17 | git clone https://github.com/NVIDIA/apex 18 | cd apex 19 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 20 | 21 | # Install other requirements: 22 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 23 | ``` 24 | 25 | - Next, install some other dependencies that are required by SPViT: 26 | 27 | ```bash 28 | pip install tensorboardX tensorboard 29 | ``` 30 | 31 | - Please refer to the 
- We start both searching and fine-tuning from the pre-trained models.
72 | 73 | #### Searching 74 | 75 | To search architectures with SPViT-Swin-T, run: 76 | 77 | ```bash 78 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_tn_l28_t32_search.yaml --resume model/swin_tiny_patch4_window7_224.pth 79 | ``` 80 | 81 | To search architectures with SPViT-Swin-S, run: 82 | 83 | ```bash 84 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_sm_l04_t55_search.yaml --resume model/swin_small_patch4_window7_224.pth 85 | ``` 86 | 87 | To search architectures with SPViT-Swin-B, run: 88 | 89 | ```bash 90 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_bs_l01_t100_search.yaml --resume model/swin_base_patch4_window7_224.pth 91 | ``` 92 | 93 | #### Fine-tuning 94 | 95 | You can start fine-tuning from either your own searched architectures or from our provided architectures by modifying and assigning the MSA indicators in `assigned_indicators` and the FFN indicators in `searching_model`. 
96 | 97 | To fine-tune architectures searched by SPViT-Swin-T, run: 98 | 99 | ```bash 100 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_tn_l28_t32_ft.yaml --resume model/swin_tiny_patch4_window7_224.pth 101 | ``` 102 | 103 | To fine-tune the architectures with SPViT-Swin-S, run: 104 | 105 | ```bash 106 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_sm_l04_t55_ft.yaml --resume model/swin_small_patch4_window7_224.pth 107 | ``` 108 | 109 | To fine-tune the architectures with SPViT-Swin-B, run: 110 | 111 | ```bash 112 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_bs_l01_t100_ft.yaml --resume model/swin_base_patch4_window7_224.pth 113 | ``` 114 | 115 | #### Evaluation 116 | 117 | We provide several examples for evaluating pre-trained SPViT models. 118 | 119 | To evaluate SPViT-Swin-T pre-trained models, run: 120 | 121 | ```bash 122 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_tn_l28_t32_ft.yaml --resume [PRE-TRAINED MODEL PATH] --opts EVAL_MODE True 123 | ``` 124 | 125 | To evaluate SPViT-Swin-S pre-trained models, run: 126 | 127 | ```bash 128 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_sm_l04_t55_ft.yaml --resume [PRE-TRAINED MODEL PATH] --opts EVAL_MODE True 129 | ``` 130 | 131 | To evaluate SPViT-Swin-B pre-trained models, run: 132 | 133 | ```bash 134 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_bs_l01_t100_ft.yaml --resume [PRE-TRAINED MODEL PATH] --opts EVAL_MODE True 135 | ``` 136 | 137 | After fine-tuning, you can optimize your checkpoint to a smaller size with the following code: 138 | ```bash 139 | python post_training_optimize_checkpoint.py 
YOUR_CHECKPOINT_PATH 140 | ``` 141 | The optimized checkpoint can be evaluated by replacing `UnifiedWindowAttention` with `UnifiedWindowAttentionParamOpt` and we have provided an example below: 142 | ```bash 143 | main_pruning.py 144 | --cfg 145 | configs/spvit_swin_tn_l28_t32_ft_dist.yaml 146 | --resume 147 | model/spvit_swin_t_l28_t32_dist_optimized.pth 148 | --opts 149 | EVAL_MODE 150 | True 151 | EXTRA.attention_type 152 | UnifiedWindowAttentionParamOpt 153 | --local_rank 154 | 0 155 | ``` 156 | #### 157 | 158 | #### TODO: 159 | 160 | ``` 161 | - [x] Release code. 162 | - [x] Release pre-trained models. 163 | ``` 164 | -------------------------------------------------------------------------------- /SPViT_Swin/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He 8 | 9 | import os 10 | import yaml 11 | from yacs.config import CfgNode as CN 12 | 13 | _C = CN() 14 | 15 | # Base config files 16 | _C.BASE = [''] 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Data settings 20 | # ----------------------------------------------------------------------------- 21 | _C.DATA = CN() 22 | # Batch size for a single GPU, could be overwritten by command line argument 23 | _C.DATA.BATCH_SIZE = 128 24 | # Path to dataset, could be overwritten by command line argument 25 | _C.DATA.DATA_PATH = '' 26 | # Dataset name 27 | _C.DATA.DATASET = 'imagenet' 28 | # Input image size 29 | _C.DATA.IMG_SIZE = 224 30 | # Interpolation to resize image (random, bilinear, bicubic) 31 | _C.DATA.INTERPOLATION = 'bicubic' 32 | # Use zipped dataset instead of folder dataset 33 | # 
could be overwritten by command line argument 34 | _C.DATA.ZIP_MODE = False 35 | # Cache Data in Memory, could be overwritten by command line argument 36 | _C.DATA.CACHE_MODE = 'part' 37 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 38 | _C.DATA.PIN_MEMORY = True 39 | # Number of data loading threads 40 | _C.DATA.NUM_WORKERS = 8 41 | 42 | # ----------------------------------------------------------------------------- 43 | # Model settings 44 | # ----------------------------------------------------------------------------- 45 | _C.MODEL = CN() 46 | # Model type 47 | _C.MODEL.TYPE = 'swin' 48 | # Model name 49 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 50 | # Checkpoint to resume, could be overwritten by command line argument 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 68 | _C.MODEL.SWIN.WINDOW_SIZE = 7 69 | _C.MODEL.SWIN.MLP_RATIO = 4. 
70 | _C.MODEL.SWIN.QKV_BIAS = True 71 | _C.MODEL.SWIN.QK_SCALE = None 72 | _C.MODEL.SWIN.APE = False 73 | _C.MODEL.SWIN.PATCH_NORM = True 74 | 75 | # ----------------------------------------------------------------------------- 76 | # Training settings 77 | # ----------------------------------------------------------------------------- 78 | _C.TRAIN = CN() 79 | _C.TRAIN.START_EPOCH = 0 80 | _C.TRAIN.EPOCHS = 300 81 | _C.TRAIN.WARMUP_EPOCHS = 20 82 | _C.TRAIN.WEIGHT_DECAY = 0.05 83 | _C.TRAIN.BASE_LR = 5e-4 84 | _C.TRAIN.WARMUP_LR = 5e-7 85 | _C.TRAIN.MIN_LR = 5e-6 86 | # Clip gradient norm 87 | _C.TRAIN.CLIP_GRAD = 5.0 88 | # Auto resume from latest checkpoint 89 | _C.TRAIN.AUTO_RESUME = True 90 | # Gradient accumulation steps 91 | # could be overwritten by command line argument 92 | _C.TRAIN.ACCUMULATION_STEPS = 0 93 | # Whether to use gradient checkpointing to save memory 94 | # could be overwritten by command line argument 95 | _C.TRAIN.USE_CHECKPOINT = False 96 | 97 | # LR scheduler 98 | _C.TRAIN.LR_SCHEDULER = CN() 99 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 100 | # Epoch interval to decay LR, used in StepLRScheduler 101 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 102 | # LR decay rate, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 104 | 105 | # Optimizer 106 | _C.TRAIN.OPTIMIZER = CN() 107 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 108 | # Optimizer Epsilon 109 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 110 | # Optimizer Betas 111 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 112 | # SGD momentum 113 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 114 | 115 | # ----------------------------------------------------------------------------- 116 | # Augmentation settings 117 | # ----------------------------------------------------------------------------- 118 | _C.AUG = CN() 119 | # Color jitter factor 120 | _C.AUG.COLOR_JITTER = 0.4 121 | # Use AutoAugment policy. 
"v0" or "original" 122 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 123 | # Random erase prob 124 | _C.AUG.REPROB = 0.25 125 | # Random erase mode 126 | _C.AUG.REMODE = 'pixel' 127 | # Random erase count 128 | _C.AUG.RECOUNT = 1 129 | # Mixup alpha, mixup enabled if > 0 130 | _C.AUG.MIXUP = 0.8 131 | # Cutmix alpha, cutmix enabled if > 0 132 | _C.AUG.CUTMIX = 1.0 133 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 134 | _C.AUG.CUTMIX_MINMAX = None 135 | # Probability of performing mixup or cutmix when either/both is enabled 136 | _C.AUG.MIXUP_PROB = 1.0 137 | # Probability of switching to cutmix when both mixup and cutmix enabled 138 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 139 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 140 | _C.AUG.MIXUP_MODE = 'batch' 141 | 142 | # ----------------------------------------------------------------------------- 143 | # Augmentation settings 144 | # ----------------------------------------------------------------------------- 145 | _C.EXTRA = CN() 146 | 147 | # Fine-tuning settings 148 | _C.EXTRA.searching_model = None 149 | _C.EXTRA.assigned_indicators = None 150 | 151 | # Architecture hyper-parameters 152 | _C.EXTRA.architecture_lr = 5e-4 153 | _C.EXTRA.arc_decay = 100 154 | _C.EXTRA.arc_warmup = 20 155 | _C.EXTRA.arc_min_lr = 5e-6 156 | 157 | # Hyper-parameters 158 | _C.EXTRA.theta = 0.5 # Bernoulli gates' initial parameter 159 | _C.EXTRA.alpha = 1e2 # Softmax temperature for ensembling heads 160 | _C.EXTRA.loss_lambda = 0.14 161 | _C.EXTRA.target_flops = 3.6 162 | 163 | _C.EXTRA.teacher_model = 'regnety_160' 164 | _C.EXTRA.teacher_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth' 165 | _C.EXTRA.distillation_type = 'none' 166 | _C.EXTRA.distillation_alpha = 0.5 167 | _C.EXTRA.distillation_tau = 1.0 168 | _C.EXTRA.attention_type = 'UnifiedWindowAttention' 169 | 170 | # ----------------------------------------------------------------------------- 171 | # Testing settings 
def _update_config_from_file(config, cfg_file):
    """Merge ``cfg_file`` into ``config``.

    Any files listed under the YAML ``BASE`` key are merged first
    (recursively), resolved relative to the directory of ``cfg_file``,
    so the current file's values win over its bases.
    """
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    base_dir = os.path.dirname(cfg_file)
    for base_cfg in yaml_cfg.setdefault('BASE', ['']):
        if base_cfg:
            _update_config_from_file(config, os.path.join(base_dir, base_cfg))

    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
def update_config(config, args):
    """Apply the config file plus command-line overrides to ``config``."""
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    # Value overrides: a truthy CLI attribute replaces the config field.
    if args.batch_size:
        config.DATA.BATCH_SIZE = args.batch_size
    if args.data_path:
        config.DATA.DATA_PATH = args.data_path
    if args.cache_mode:
        config.DATA.CACHE_MODE = args.cache_mode
    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.accumulation_steps:
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if args.amp_opt_level:
        config.AMP_OPT_LEVEL = args.amp_opt_level
    if args.output:
        config.OUTPUT = args.output
    if args.tag:
        config.TAG = args.tag

    # Boolean flags.
    if args.zip:
        config.DATA.ZIP_MODE = True
    if args.use_checkpoint:
        config.TRAIN.USE_CHECKPOINT = True
    if args.eval:
        config.EVAL_MODE = True
    if args.throughput:
        config.THROUGHPUT_MODE = True

    # Local rank for DistributedDataParallel, given by the launcher.
    config.LOCAL_RANK = args.local_rank

    # Final output folder layout: OUTPUT/<model name>/<tag>.
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
    config.freeze()


def get_config(args):
    """Return a fully-updated yacs CfgNode.

    Works on a clone so the module-level defaults are never mutated.
    """
    config = _C.clone()
    update_config(config, args)
    return config
1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] 25 | searching_model: 'ffn_indicators/spvit_swin_bs_l01_t100_search_20epoch.pth' 26 | TRAIN: 27 | EPOCHS: 130 28 | WARMUP_EPOCHS: 0 29 | BASE_LR: 5e-5 30 | #EVAL_MODE: True -------------------------------------------------------------------------------- /SPViT_Swin/configs/spvit_swin_bs_l01_t100_search.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: spvit_swin 3 | NAME: spvit_swin_bs_l01_t100_search 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | NUM_WORKERS: 10 12 | BATCH_SIZE: 92 13 | DATA_PATH: dataset/imagenet 14 | DATASET: imagenet 15 | EXTRA: 16 | loss_lambda: 0.1 17 | arc_decay: 150 18 | target_flops: 10.0 19 | arc_warmup: 0 20 | arc_min_lr: 5e-4 21 | architecture_lr: 5e-4 22 | alpha: 1e2 23 | theta: 1.5 24 | TRAIN: 25 | EPOCHS: 300 26 | WARMUP_EPOCHS: 0 27 | BASE_LR: 5e-5 28 | MIN_LR: 5e-5 29 | #EVAL_MODE: True -------------------------------------------------------------------------------- /SPViT_Swin/configs/spvit_swin_sm_l04_t55_ft.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: spvit_swin 3 | NAME: spvit_swin_sm_l04_t55_ft 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | NUM_WORKERS: 10 12 | BATCH_SIZE: 128 13 | DATA_PATH: dataset/imagenet 14 | DATASET: imagenet 15 | EXTRA: 16 | loss_lambda: 0.4 17 | arc_decay: 150 18 | target_flops: 5.5 19 | arc_warmup: 0 20 | arc_min_lr: 5e-4 21 
| architecture_lr: 5e-4 22 | alpha: 1e2 23 | theta: 1.5 24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] 25 | searching_model: 'ffn_indicators/spvit_swin_sm_l04_t55_search_14epoch.pth' 26 | TRAIN: 27 | EPOCHS: 130 28 | WARMUP_EPOCHS: 0 29 | BASE_LR: 5e-5 30 | #EVAL_MODE: True -------------------------------------------------------------------------------- /SPViT_Swin/configs/spvit_swin_sm_l04_t55_ft_dist.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: spvit_swin 3 | NAME: spvit_swin_sm_l04_t55_ft_dist 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | NUM_WORKERS: 10 12 | BATCH_SIZE: 128 13 | DATA_PATH: dataset/imagenet 14 | DATASET: imagenet 15 | EXTRA: 16 | loss_lambda: 0.4 17 | arc_decay: 150 18 | target_flops: 5.5 19 | arc_warmup: 0 20 | arc_min_lr: 5e-4 21 | architecture_lr: 5e-4 22 | alpha: 1e2 23 | theta: 1.5 24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] 25 | searching_model: 'ffn_indicators/spvit_swin_sm_l04_t55_search_14epoch.pth' 26 | distillation_type: 'hard' 27 | TRAIN: 28 | EPOCHS: 200 29 | WARMUP_EPOCHS: 0 30 | BASE_LR: 
5e-5 31 | #EVAL_MODE: True -------------------------------------------------------------------------------- /SPViT_Swin/configs/spvit_swin_sm_l04_t55_search.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: spvit_swin 3 | NAME: spvit_swin_sm_l04_t55_search 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | NUM_WORKERS: 10 12 | BATCH_SIZE: 128 13 | DATA_PATH: dataset/imagenet 14 | DATASET: imagenet 15 | EXTRA: 16 | loss_lambda: 0.4 17 | arc_decay: 150 18 | target_flops: 5.5 19 | arc_warmup: 0 20 | arc_min_lr: 5e-4 21 | architecture_lr: 5e-4 22 | alpha: 1e2 23 | theta: 1.5 24 | TRAIN: 25 | EPOCHS: 300 26 | WARMUP_EPOCHS: 0 27 | BASE_LR: 5e-5 28 | MIN_LR: 5e-5 29 | #EVAL_MODE: True -------------------------------------------------------------------------------- /SPViT_Swin/configs/spvit_swin_tn_l28_t32_ft.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: spvit_swin 3 | NAME: spvit_swin_tn_l28_t32_ft 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | NUM_WORKERS: 10 12 | BATCH_SIZE: 128 13 | DATA_PATH: dataset/imagenet 14 | DATASET: imagenet 15 | EXTRA: 16 | loss_lambda: 2.8 17 | arc_decay: 150 18 | target_flops: 3.2 19 | arc_warmup: 0 20 | arc_min_lr: 5e-4 21 | architecture_lr: 5e-4 22 | alpha: 1e2 23 | theta: 1.5 24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] 25 | searching_model: 'ffn_indicators/spvit_swin_t_l28_t32_search_12epoch.pth' 26 | TRAIN: 27 | EPOCHS: 130 28 | WARMUP_EPOCHS: 0 29 | BASE_LR: 5e-5 30 | #EVAL_MODE: True 
-------------------------------------------------------------------------------- /SPViT_Swin/configs/spvit_swin_tn_l28_t32_ft_dist.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: spvit_swin 3 | NAME: spvit_swin_tn_l28_t32_ft_dist 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | NUM_WORKERS: 10 12 | BATCH_SIZE: 128 13 | DATA_PATH: dataset/imagenet 14 | DATASET: imagenet 15 | EXTRA: 16 | loss_lambda: 2.8 17 | arc_decay: 150 18 | target_flops: 3.2 19 | arc_warmup: 0 20 | arc_min_lr: 5e-4 21 | architecture_lr: 5e-4 22 | alpha: 1e2 23 | theta: 1.5 24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]] 25 | searching_model: 'ffn_indicators/spvit_swin_t_l28_t32_search_12epoch.pth' 26 | distillation_type: 'hard' 27 | TRAIN: 28 | EPOCHS: 200 29 | WARMUP_EPOCHS: 0 30 | BASE_LR: 5e-5 31 | #EVAL_MODE: True -------------------------------------------------------------------------------- /SPViT_Swin/configs/spvit_swin_tn_l28_t32_search.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: spvit_swin 3 | NAME: spvit_swin_tn_l28_t32_search 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | NUM_WORKERS: 10 12 | BATCH_SIZE: 128 13 | DATA_PATH: dataset/imagenet 14 | DATASET: imagenet 15 | EXTRA: 16 | loss_lambda: 2.8 17 | arc_decay: 150 18 | target_flops: 3.2 19 | arc_warmup: 0 20 | arc_min_lr: 5e-4 21 | architecture_lr: 5e-4 22 | alpha: 1e2 23 | theta: 1.5 24 | TRAIN: 25 | EPOCHS: 300 26 | WARMUP_EPOCHS: 0 27 | BASE_LR: 5e-5 28 | MIN_LR: 5e-5 29 | #EVAL_MODE: True 
-------------------------------------------------------------------------------- /SPViT_Swin/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader, build_loader_darts, build_loader_darts_v2 -------------------------------------------------------------------------------- /SPViT_Swin/data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He 8 | 9 | import os 10 | import torch 11 | import numpy as np 12 | import torch.distributed as dist 13 | from torchvision import datasets, transforms 14 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 15 | from timm.data import Mixup 16 | from timm.data import create_transform 17 | from timm.data.transforms import _pil_interp 18 | 19 | from .cached_image_folder import CachedImageFolder 20 | from .samplers import SubsetRandomSampler 21 | from utils import DistributedSamplerWrapper 22 | 23 | def build_loader(config): 24 | config.defrost() 25 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 26 | config.freeze() 27 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 28 | dataset_val, _ = build_dataset(is_train=False, config=config) 29 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 30 | 31 | num_tasks = dist.get_world_size() 32 | global_rank = dist.get_rank() 33 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 34 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 35 
| sampler_train = SubsetRandomSampler(indices) 36 | else: 37 | sampler_train = torch.utils.data.DistributedSampler( 38 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 39 | ) 40 | 41 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 42 | sampler_val = SubsetRandomSampler(indices) 43 | 44 | data_loader_train = torch.utils.data.DataLoader( 45 | dataset_train, sampler=sampler_train, 46 | batch_size=config.DATA.BATCH_SIZE, 47 | num_workers=config.DATA.NUM_WORKERS, 48 | pin_memory=config.DATA.PIN_MEMORY, 49 | drop_last=True, 50 | ) 51 | 52 | data_loader_val = torch.utils.data.DataLoader( 53 | dataset_val, sampler=sampler_val, 54 | batch_size=config.DATA.BATCH_SIZE, 55 | shuffle=False, 56 | num_workers=config.DATA.NUM_WORKERS, 57 | pin_memory=config.DATA.PIN_MEMORY, 58 | drop_last=False 59 | ) 60 | 61 | # setup mixup / cutmix 62 | mixup_fn = None 63 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None 64 | if mixup_active: 65 | mixup_fn = Mixup( 66 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 67 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 68 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 69 | 70 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 71 | 72 | 73 | def build_loader_darts(config): 74 | config.defrost() 75 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 76 | config.freeze() 77 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 78 | dataset_val, _ = build_dataset(is_train=False, config=config) 79 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 80 | 81 | num_tasks = dist.get_world_size() 82 | global_rank = dist.get_rank() 83 | 
if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 84 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 85 | sampler_train = SubsetRandomSampler(indices) 86 | else: 87 | 88 | num_train = len(dataset_train) 89 | indices = list(range(num_train)) 90 | split = int(np.floor(0.5 * num_train)) 91 | traintrain_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split]) 92 | trainval_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]) 93 | dis_traintrain_sampler = DistributedSamplerWrapper(traintrain_sampler, num_replicas=num_tasks, rank=global_rank, shuffle=True) 94 | dis_trainval_sampler = DistributedSamplerWrapper(trainval_sampler, num_replicas=num_tasks, rank=global_rank, shuffle=True) 95 | 96 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 97 | sampler_val = SubsetRandomSampler(indices) 98 | 99 | train_queue = torch.utils.data.DataLoader( 100 | dataset_train, sampler=dis_traintrain_sampler, 101 | batch_size=config.DATA.BATCH_SIZE, 102 | num_workers=0, 103 | pin_memory=config.DATA.PIN_MEMORY, 104 | drop_last=True, 105 | ) 106 | 107 | val_queue = torch.utils.data.DataLoader( 108 | dataset_train, sampler=dis_trainval_sampler, 109 | batch_size=config.DATA.BATCH_SIZE, 110 | num_workers=0, 111 | pin_memory=config.DATA.PIN_MEMORY, 112 | drop_last=True, 113 | ) 114 | 115 | data_loader_val = torch.utils.data.DataLoader( 116 | dataset_val, sampler=sampler_val, 117 | batch_size=config.DATA.BATCH_SIZE, 118 | shuffle=False, 119 | num_workers=config.DATA.NUM_WORKERS, 120 | pin_memory=config.DATA.PIN_MEMORY, 121 | drop_last=False 122 | ) 123 | 124 | # setup mixup / cutmix 125 | mixup_fn = None 126 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 127 | if mixup_active: 128 | mixup_fn = Mixup( 129 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 130 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 131 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 132 | 133 | return dataset_train, dataset_val, train_queue, val_queue, data_loader_val, mixup_fn 134 | 135 | 136 | def build_loader_darts_v2(config): 137 | config.defrost() 138 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 139 | config.freeze() 140 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 141 | dataset_val, _ = build_dataset(is_train=False, config=config) 142 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 143 | 144 | num_tasks = dist.get_world_size() 145 | global_rank = dist.get_rank() 146 | 147 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 148 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 149 | sampler_train = SubsetRandomSampler(indices) 150 | else: 151 | num_train = len(dataset_train) 152 | split = int(np.floor(0.5 * num_train)) 153 | 154 | dataset_traintrain, dataset_trainval = torch.utils.data.random_split( 155 | dataset_train, 156 | (len(dataset_train) - split, split) 157 | ) 158 | 159 | sampler_traintrain = torch.utils.data.DistributedSampler( 160 | dataset_traintrain, num_replicas=num_tasks, rank=global_rank, shuffle=True 161 | ) 162 | 163 | sampler_trainval = torch.utils.data.DistributedSampler( 164 | dataset_trainval, num_replicas=num_tasks, rank=global_rank, shuffle=True 165 | ) 166 | 167 | # sampler_traintrain = torch.utils.data.RandomSampler( 168 | # dataset_traintrain 169 | # ) 170 | # 171 | # sampler_trainval = torch.utils.data.RandomSampler( 172 | # 
dataset_trainval 173 | # ) 174 | 175 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 176 | sampler_val = SubsetRandomSampler(indices) 177 | 178 | train_queue = torch.utils.data.DataLoader( 179 | dataset_train, sampler=sampler_traintrain, 180 | batch_size=config.DATA.BATCH_SIZE, 181 | num_workers=2, 182 | pin_memory=config.DATA.PIN_MEMORY, 183 | drop_last=True, 184 | persistent_workers=False 185 | ) 186 | 187 | val_queue = torch.utils.data.DataLoader( 188 | dataset_train, sampler=sampler_trainval, 189 | batch_size=config.DATA.BATCH_SIZE, 190 | num_workers=2, 191 | pin_memory=config.DATA.PIN_MEMORY, 192 | drop_last=True, 193 | persistent_workers=False 194 | ) 195 | 196 | data_loader_val = torch.utils.data.DataLoader( 197 | dataset_val, sampler=sampler_val, 198 | batch_size=config.DATA.BATCH_SIZE, 199 | shuffle=False, 200 | num_workers=config.DATA.NUM_WORKERS, 201 | pin_memory=config.DATA.PIN_MEMORY, 202 | drop_last=False 203 | ) 204 | 205 | # setup mixup / cutmix 206 | mixup_fn = None 207 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 208 | if mixup_active: 209 | mixup_fn = Mixup( 210 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 211 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 212 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 213 | 214 | return dataset_train, dataset_val, train_queue, val_queue, data_loader_val, mixup_fn 215 | 216 | 217 | def build_dataset(is_train, config): 218 | transform = build_transform(is_train, config) 219 | if config.DATA.DATASET == 'imagenet': 220 | prefix = 'train' if is_train else 'val' 221 | if config.DATA.ZIP_MODE: 222 | ann_file = prefix + "_map.txt" 223 | prefix = prefix + ".zip@/" 224 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 225 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 226 | else: 227 | root = os.path.join(config.DATA.DATA_PATH, prefix) 228 | dataset = datasets.ImageFolder(root, transform=transform) 229 | nb_classes = 1000 230 | elif config.DATA.DATASET == 'cifar': 231 | dataset = datasets.CIFAR100(config.DATA.DATA_PATH, train=is_train, transform=transform) 232 | nb_classes = 100 233 | elif config.DATA.DATASET == 'imagenet100': 234 | root = os.path.join(config.DATA.DATA_PATH, 'train100' if is_train else 'val100') 235 | dataset = datasets.ImageFolder(root, transform=transform) 236 | nb_classes = 100 237 | else: 238 | raise NotImplementedError("We only support ImageNet Now.") 239 | 240 | return dataset, nb_classes 241 | 242 | 243 | def build_transform(is_train, config): 244 | resize_im = config.DATA.IMG_SIZE > 32 245 | if is_train: 246 | # this should always dispatch to transforms_imagenet_train 247 | transform = create_transform( 248 | input_size=config.DATA.IMG_SIZE, 249 | is_training=True, 250 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 251 | 
auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 252 | re_prob=config.AUG.REPROB, 253 | re_mode=config.AUG.REMODE, 254 | re_count=config.AUG.RECOUNT, 255 | interpolation=config.DATA.INTERPOLATION, 256 | ) 257 | if not resize_im: 258 | # replace RandomResizedCropAndInterpolation with 259 | # RandomCrop 260 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 261 | return transform 262 | 263 | t = [] 264 | if resize_im: 265 | if config.TEST.CROP: 266 | size = int((256 / 224) * config.DATA.IMG_SIZE) 267 | t.append( 268 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 269 | # to maintain same ratio w.r.t. 224 images 270 | ) 271 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 272 | else: 273 | t.append( 274 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 275 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 276 | ) 277 | 278 | t.append(transforms.ToTensor()) 279 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 280 | return transforms.Compose(t) 281 | -------------------------------------------------------------------------------- /SPViT_Swin/data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He 8 | 9 | import io 10 | import os 11 | import time 12 | import torch.distributed as dist 13 | import torch.utils.data as data 14 | from PIL import Image 15 | 16 | from .zipreader import is_zip_path, ZipReader 17 | 18 | 19 | def has_file_allowed_extension(filename, extensions): 20 | """Checks if a file is an allowed extension. 
21 | Args: 22 | filename (string): path to a file 23 | Returns: 24 | bool: True if the filename ends with a known image extension 25 | """ 26 | filename_lower = filename.lower() 27 | return any(filename_lower.endswith(ext) for ext in extensions) 28 | 29 | 30 | def find_classes(dir): 31 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 32 | classes.sort() 33 | class_to_idx = {classes[i]: i for i in range(len(classes))} 34 | return classes, class_to_idx 35 | 36 | 37 | def make_dataset(dir, class_to_idx, extensions): 38 | images = [] 39 | dir = os.path.expanduser(dir) 40 | for target in sorted(os.listdir(dir)): 41 | d = os.path.join(dir, target) 42 | if not os.path.isdir(d): 43 | continue 44 | 45 | for root, _, fnames in sorted(os.walk(d)): 46 | for fname in sorted(fnames): 47 | if has_file_allowed_extension(fname, extensions): 48 | path = os.path.join(root, fname) 49 | item = (path, class_to_idx[target]) 50 | images.append(item) 51 | 52 | return images 53 | 54 | 55 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 56 | images = [] 57 | with open(ann_file, "r") as f: 58 | contents = f.readlines() 59 | for line_str in contents: 60 | path_contents = [c for c in line_str.split('\t')] 61 | im_file_name = path_contents[0] 62 | class_index = int(path_contents[1]) 63 | 64 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 65 | item = (os.path.join(img_prefix, im_file_name), class_index) 66 | 67 | images.append(item) 68 | 69 | return images 70 | 71 | 72 | class DatasetFolder(data.Dataset): 73 | """A generic data loader where the samples are arranged in this way: :: 74 | root/class_x/xxx.ext 75 | root/class_x/xxy.ext 76 | root/class_x/xxz.ext 77 | root/class_y/123.ext 78 | root/class_y/nsdf3.ext 79 | root/class_y/asd932_.ext 80 | Args: 81 | root (string): Root directory path. 82 | loader (callable): A function to load a sample given its path. 83 | extensions (list[string]): A list of allowed extensions. 
84 | transform (callable, optional): A function/transform that takes in 85 | a sample and returns a transformed version. 86 | E.g, ``transforms.RandomCrop`` for images. 87 | target_transform (callable, optional): A function/transform that takes 88 | in the target and transforms it. 89 | Attributes: 90 | samples (list): List of (sample path, class_index) tuples 91 | """ 92 | 93 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 94 | cache_mode="no"): 95 | # image folder mode 96 | if ann_file == '': 97 | _, class_to_idx = find_classes(root) 98 | samples = make_dataset(root, class_to_idx, extensions) 99 | # zip mode 100 | else: 101 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 102 | os.path.join(root, img_prefix), 103 | extensions) 104 | 105 | if len(samples) == 0: 106 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 107 | "Supported extensions are: " + ",".join(extensions))) 108 | 109 | self.root = root 110 | self.loader = loader 111 | self.extensions = extensions 112 | 113 | self.samples = samples 114 | self.labels = [y_1k for _, y_1k in samples] 115 | self.classes = list(set(self.labels)) 116 | 117 | self.transform = transform 118 | self.target_transform = target_transform 119 | 120 | self.cache_mode = cache_mode 121 | if self.cache_mode != "no": 122 | self.init_cache() 123 | 124 | def init_cache(self): 125 | assert self.cache_mode in ["part", "full"] 126 | n_sample = len(self.samples) 127 | global_rank = dist.get_rank() 128 | world_size = dist.get_world_size() 129 | 130 | samples_bytes = [None for _ in range(n_sample)] 131 | start_time = time.time() 132 | for index in range(n_sample): 133 | if index % (n_sample // 10) == 0: 134 | t = time.time() - start_time 135 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 136 | start_time = time.time() 137 | path, target = self.samples[index] 138 | if self.cache_mode == "full": 
139 | samples_bytes[index] = (ZipReader.read(path), target) 140 | elif self.cache_mode == "part" and index % world_size == global_rank: 141 | samples_bytes[index] = (ZipReader.read(path), target) 142 | else: 143 | samples_bytes[index] = (path, target) 144 | self.samples = samples_bytes 145 | 146 | def __getitem__(self, index): 147 | """ 148 | Args: 149 | index (int): Index 150 | Returns: 151 | tuple: (sample, target) where target is class_index of the target class. 152 | """ 153 | path, target = self.samples[index] 154 | sample = self.loader(path) 155 | if self.transform is not None: 156 | sample = self.transform(sample) 157 | if self.target_transform is not None: 158 | target = self.target_transform(target) 159 | 160 | return sample, target 161 | 162 | def __len__(self): 163 | return len(self.samples) 164 | 165 | def __repr__(self): 166 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 167 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 168 | fmt_str += ' Root Location: {}\n'.format(self.root) 169 | tmp = ' Transforms (if any): ' 170 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 171 | tmp = ' Target Transforms (if any): ' 172 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 173 | return fmt_str 174 | 175 | 176 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 177 | 178 | 179 | def pil_loader(path): 180 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 181 | if isinstance(path, bytes): 182 | img = Image.open(io.BytesIO(path)) 183 | elif is_zip_path(path): 184 | data = ZipReader.read(path) 185 | img = Image.open(io.BytesIO(data)) 186 | else: 187 | with open(path, 'rb') as f: 188 | img = Image.open(f) 189 | return img.convert('RGB') 190 | 191 | 192 | def accimage_loader(path): 193 | import accimage 194 | try: 195 | return accimage.Image(path) 196 | except 
IOError: 197 | # Potentially a decoding problem, fall back to PIL.Image 198 | return pil_loader(path) 199 | 200 | 201 | def default_img_loader(path): 202 | from torchvision import get_image_backend 203 | if get_image_backend() == 'accimage': 204 | return accimage_loader(path) 205 | else: 206 | return pil_loader(path) 207 | 208 | 209 | class CachedImageFolder(DatasetFolder): 210 | """A generic data loader where the images are arranged in this way: :: 211 | root/dog/xxx.png 212 | root/dog/xxy.png 213 | root/dog/xxz.png 214 | root/cat/123.png 215 | root/cat/nsdf3.png 216 | root/cat/asd932_.png 217 | Args: 218 | root (string): Root directory path. 219 | transform (callable, optional): A function/transform that takes in an PIL image 220 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 221 | target_transform (callable, optional): A function/transform that takes in the 222 | target and transforms it. 223 | loader (callable, optional): A function to load an image given its path. 224 | Attributes: 225 | imgs (list): List of (image path, class_index) tuples 226 | """ 227 | 228 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 229 | loader=default_img_loader, cache_mode="no"): 230 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 231 | ann_file=ann_file, img_prefix=img_prefix, 232 | transform=transform, target_transform=target_transform, 233 | cache_mode=cache_mode) 234 | self.imgs = self.samples 235 | 236 | def __getitem__(self, index): 237 | """ 238 | Args: 239 | index (int): Index 240 | Returns: 241 | tuple: (image, target) where target is class_index of the target class. 
242 | """ 243 | path, target = self.samples[index] 244 | image = self.loader(path) 245 | if self.transform is not None: 246 | img = self.transform(image) 247 | else: 248 | img = image 249 | if self.target_transform is not None: 250 | target = self.target_transform(target) 251 | 252 | return img, target 253 | -------------------------------------------------------------------------------- /SPViT_Swin/data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He 8 | 9 | import torch 10 | 11 | 12 | class SubsetRandomSampler(torch.utils.data.Sampler): 13 | r"""Samples elements randomly from a given list of indices, without replacement. 
14 | 15 | Arguments: 16 | indices (sequence): a sequence of indices 17 | """ 18 | 19 | def __init__(self, indices): 20 | self.epoch = 0 21 | self.indices = indices 22 | 23 | def __iter__(self): 24 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 25 | 26 | def __len__(self): 27 | return len(self.indices) 28 | 29 | def set_epoch(self, epoch): 30 | self.epoch = epoch 31 | -------------------------------------------------------------------------------- /SPViT_Swin/data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He 8 | 9 | import os 10 | import zipfile 11 | import io 12 | import numpy as np 13 | from PIL import Image 14 | from PIL import ImageFile 15 | 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | 18 | 19 | def is_zip_path(img_or_path): 20 | """judge if this is a zip path""" 21 | return '.zip@' in img_or_path 22 | 23 | 24 | class ZipReader(object): 25 | """A class to read zipped files""" 26 | zip_bank = dict() 27 | 28 | def __init__(self): 29 | super(ZipReader, self).__init__() 30 | 31 | @staticmethod 32 | def get_zipfile(path): 33 | zip_bank = ZipReader.zip_bank 34 | if path not in zip_bank: 35 | zfile = zipfile.ZipFile(path, 'r') 36 | zip_bank[path] = zfile 37 | return zip_bank[path] 38 | 39 | @staticmethod 40 | def split_zip_style_path(path): 41 | pos_at = path.index('@') 42 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 43 | 44 | zip_path = path[0: pos_at] 45 | folder_path = path[pos_at + 1:] 46 | folder_path = str.strip(folder_path, '/') 47 | return zip_path, folder_path 48 | 49 | @staticmethod 50 | def list_folder(path): 51 | 
zip_path, folder_path = ZipReader.split_zip_style_path(path) 52 | 53 | zfile = ZipReader.get_zipfile(zip_path) 54 | folder_list = [] 55 | for file_foler_name in zfile.namelist(): 56 | file_foler_name = str.strip(file_foler_name, '/') 57 | if file_foler_name.startswith(folder_path) and \ 58 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 59 | file_foler_name != folder_path: 60 | if len(folder_path) == 0: 61 | folder_list.append(file_foler_name) 62 | else: 63 | folder_list.append(file_foler_name[len(folder_path) + 1:]) 64 | 65 | return folder_list 66 | 67 | @staticmethod 68 | def list_files(path, extension=None): 69 | if extension is None: 70 | extension = ['.*'] 71 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 72 | 73 | zfile = ZipReader.get_zipfile(zip_path) 74 | file_lists = [] 75 | for file_foler_name in zfile.namelist(): 76 | file_foler_name = str.strip(file_foler_name, '/') 77 | if file_foler_name.startswith(folder_path) and \ 78 | str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 79 | if len(folder_path) == 0: 80 | file_lists.append(file_foler_name) 81 | else: 82 | file_lists.append(file_foler_name[len(folder_path) + 1:]) 83 | 84 | return file_lists 85 | 86 | @staticmethod 87 | def read(path): 88 | zip_path, path_img = ZipReader.split_zip_style_path(path) 89 | zfile = ZipReader.get_zipfile(zip_path) 90 | data = zfile.read(path_img) 91 | return data 92 | 93 | @staticmethod 94 | def imread(path): 95 | zip_path, path_img = ZipReader.split_zip_style_path(path) 96 | zfile = ZipReader.get_zipfile(zip_path) 97 | data = zfile.read(path_img) 98 | try: 99 | im = Image.open(io.BytesIO(data)) 100 | except: 101 | print("ERROR IMG LOADED: ", path_img) 102 | random_img = np.random.rand(224, 224, 3) * 255 103 | im = Image.fromarray(np.uint8(random_img)) 104 | return im 105 | -------------------------------------------------------------------------------- /SPViT_Swin/dev/README.md: 
-------------------------------------------------------------------------------- 1 | 2 | ## Some scripts for developers to use, include: 3 | 4 | - `linter.sh`: lint the codebase before commit 5 | - `run_{inference,instant}_tests.sh`: run inference/training for a few iterations. 6 | Note that these tests require 2 GPUs. 7 | - `parse_results.sh`: parse results from a log file. 8 | -------------------------------------------------------------------------------- /SPViT_Swin/dev/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | # Run this script at project root by "./dev/linter.sh" before you commit 5 | 6 | vergte() { 7 | [ "$2" = "$(echo -e "$1\\n$2" | sort -V | head -n1)" ] 8 | } 9 | 10 | { 11 | black --version | grep -E "(19.3b0.*6733274)|(19.3b0\\+8)" > /dev/null 12 | } || { 13 | echo "Linter requires 'black @ git+https://github.com/psf/black@673327449f86fce558adde153bb6cbe54bfebad2' !" 14 | exit 1 15 | } 16 | 17 | ISORT_TARGET_VERSION="4.3.21" 18 | ISORT_VERSION=$(isort -v | grep VERSION | awk '{print $2}') 19 | vergte "$ISORT_VERSION" "$ISORT_TARGET_VERSION" || { 20 | echo "Linter requires isort>=${ISORT_TARGET_VERSION} !" 21 | exit 1 22 | } 23 | 24 | set -v 25 | 26 | echo "Running isort ..." 27 | isort -y -sp . --atomic 28 | 29 | echo "Running black ..." 30 | black -l 100 . 31 | 32 | echo "Running flake8 ..." 33 | if [ -x "$(command -v flake8-3)" ]; then 34 | flake8-3 . 35 | else 36 | python3 -m flake8 . 37 | fi 38 | 39 | # echo "Running mypy ..." 40 | # Pytorch does not have enough type annotations 41 | # mypy detectron2/solver detectron2/structures detectron2/config 42 | 43 | echo "Running clang-format ..." 44 | find . 
-regex ".*\.\(cpp\|c\|cc\|cu\|cxx\|h\|hh\|hpp\|hxx\|tcc\|mm\|m\)" -print0 | xargs -0 clang-format -i 45 | 46 | command -v arc > /dev/null && arc lint 47 | -------------------------------------------------------------------------------- /SPViT_Swin/dev/packaging/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## To build a cu101 wheel for release: 3 | 4 | ``` 5 | $ nvidia-docker run -it --storage-opt "size=20GB" --name pt pytorch/manylinux-cuda101 6 | # inside the container: 7 | # git clone https://github.com/facebookresearch/detectron2/ 8 | # cd detectron2 9 | # export CU_VERSION=cu101 D2_VERSION_SUFFIX= PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 10 | # ./dev/packaging/build_wheel.sh 11 | ``` 12 | 13 | ## To build all wheels for `CUDA {9.2,10.0,10.1}` x `Python {3.6,3.7,3.8}`: 14 | ``` 15 | ./dev/packaging/build_all_wheels.sh 16 | ./dev/packaging/gen_wheel_index.sh /path/to/wheels 17 | ``` 18 | -------------------------------------------------------------------------------- /SPViT_Swin/dev/packaging/build_all_wheels.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | [[ -d "dev/packaging" ]] || { 5 | echo "Please run this script at detectron2 root!" 6 | exit 1 7 | } 8 | 9 | build_one() { 10 | cu=$1 11 | pytorch_ver=$2 12 | 13 | case "$cu" in 14 | cu*) 15 | container_name=manylinux-cuda${cu/cu/} 16 | ;; 17 | cpu) 18 | container_name=manylinux-cuda101 19 | ;; 20 | *) 21 | echo "Unrecognized cu=$cu" 22 | exit 1 23 | ;; 24 | esac 25 | 26 | echo "Launching container $container_name ..." 27 | 28 | for py in 3.6 3.7 3.8; do 29 | docker run -itd \ 30 | --name $container_name \ 31 | --mount type=bind,source="$(pwd)",target=/detectron2 \ 32 | pytorch/$container_name 33 | 34 | cat </dev/null 2>&1 && pwd )" 8 | . 
"$script_dir/pkg_helpers.bash" 9 | 10 | echo "Build Settings:" 11 | echo "CU_VERSION: $CU_VERSION" # e.g. cu101 12 | echo "D2_VERSION_SUFFIX: $D2_VERSION_SUFFIX" # e.g. +cu101 or "" 13 | echo "PYTHON_VERSION: $PYTHON_VERSION" # e.g. 3.6 14 | echo "PYTORCH_VERSION: $PYTORCH_VERSION" # e.g. 1.4 15 | 16 | setup_cuda 17 | setup_wheel_python 18 | yum install ninja-build -y && ln -sv /usr/bin/ninja-build /usr/bin/ninja 19 | 20 | pip_install pip numpy -U 21 | pip_install "torch==$PYTORCH_VERSION" \ 22 | -f https://download.pytorch.org/whl/"$CU_VERSION"/torch_stable.html 23 | 24 | # use separate directories to allow parallel build 25 | BASE_BUILD_DIR=build/cu$CU_VERSION-py$PYTHON_VERSION-pt$PYTORCH_VERSION 26 | python setup.py \ 27 | build -b "$BASE_BUILD_DIR" \ 28 | bdist_wheel -b "$BASE_BUILD_DIR/build_dist" -d "wheels/$CU_VERSION/torch$PYTORCH_VERSION" 29 | rm -rf "$BASE_BUILD_DIR" 30 | -------------------------------------------------------------------------------- /SPViT_Swin/dev/packaging/gen_install_table.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | 6 | template = """
install
\
 7 | python -m pip install detectron2{d2_version} -f \\
 8 |   https://dl.fbaipublicfiles.com/detectron2/wheels/{cuda}/torch{torch}/index.html
 9 | 
""" 10 | CUDA_SUFFIX = {"10.2": "cu102", "10.1": "cu101", "10.0": "cu100", "9.2": "cu92", "cpu": "cpu"} 11 | 12 | 13 | def gen_header(torch_versions): 14 | return '' + "".join( 15 | [ 16 | ''.format(t) 17 | for t in torch_versions 18 | ] 19 | ) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--d2-version", help="detectron2 version number, default to empty") 25 | args = parser.parse_args() 26 | d2_version = f"=={args.d2_version}" if args.d2_version else "" 27 | 28 | all_versions = [("1.4", k) for k in ["10.1", "10.0", "9.2", "cpu"]] + [ 29 | ("1.5", k) for k in ["10.2", "10.1", "9.2", "cpu"] 30 | ] 31 | 32 | torch_versions = sorted({k[0] for k in all_versions}, key=float, reverse=True) 33 | cuda_versions = sorted( 34 | {k[1] for k in all_versions}, key=lambda x: float(x) if x != "cpu" else 0, reverse=True 35 | ) 36 | 37 | table = gen_header(torch_versions) 38 | for cu in cuda_versions: 39 | table += f""" """ 40 | cu_suffix = CUDA_SUFFIX[cu] 41 | for torch in torch_versions: 42 | if (torch, cu) in all_versions: 43 | cell = template.format(d2_version=d2_version, cuda=cu_suffix, torch=torch) 44 | else: 45 | cell = "" 46 | table += f""" """ 47 | table += "" 48 | table += "
CUDA torch {}
{cu}{cell}
" 49 | print(table) 50 | -------------------------------------------------------------------------------- /SPViT_Swin/dev/packaging/gen_wheel_index.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | 5 | root=$1 6 | if [[ -z "$root" ]]; then 7 | echo "Usage: ./gen_wheel_index.sh /path/to/wheels" 8 | exit 9 | fi 10 | 11 | export LC_ALL=C # reproducible sort 12 | # NOTE: all sort in this script might not work when xx.10 is released 13 | 14 | index=$root/index.html 15 | 16 | cd "$root" 17 | for cu in cpu cu92 cu100 cu101 cu102; do 18 | cd "$root/$cu" 19 | echo "Creating $PWD/index.html ..." 20 | # First sort by torch version, then stable sort by d2 version with unique. 21 | # As a result, the latest torch version for each d2 version is kept. 22 | for whl in $(find -type f -name '*.whl' -printf '%P\n' \ 23 | | sort -k 1 -r | sort -t '/' -k 2 --stable -r --unique); do 24 | echo "$whl
" 25 | done > index.html 26 | 27 | 28 | for torch in torch*; do 29 | cd "$root/$cu/$torch" 30 | 31 | # list all whl for each cuda,torch version 32 | echo "Creating $PWD/index.html ..." 33 | for whl in $(find . -type f -name '*.whl' -printf '%P\n' | sort -r); do 34 | echo "$whl
" 35 | done > index.html 36 | done 37 | done 38 | 39 | cd "$root" 40 | # Just list everything: 41 | echo "Creating $index ..." 42 | for whl in $(find . -type f -name '*.whl' -printf '%P\n' | sort -r); do 43 | echo "$whl
" 44 | done > "$index" 45 | 46 | -------------------------------------------------------------------------------- /SPViT_Swin/dev/packaging/pkg_helpers.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | # Function to retry functions that sometimes timeout or have flaky failures 5 | retry () { 6 | $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) 7 | } 8 | # Install with pip a bit more robustly than the default 9 | pip_install() { 10 | retry pip install --progress-bar off "$@" 11 | } 12 | 13 | 14 | setup_cuda() { 15 | # Now work out the CUDA settings 16 | # Like other torch domain libraries, we choose common GPU architectures only. 17 | export FORCE_CUDA=1 18 | case "$CU_VERSION" in 19 | cu102) 20 | export CUDA_HOME=/usr/local/cuda-10.2/ 21 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX;7.5+PTX" 22 | ;; 23 | cu101) 24 | export CUDA_HOME=/usr/local/cuda-10.1/ 25 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX;7.5+PTX" 26 | ;; 27 | cu100) 28 | export CUDA_HOME=/usr/local/cuda-10.0/ 29 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX;7.5+PTX" 30 | ;; 31 | cu92) 32 | export CUDA_HOME=/usr/local/cuda-9.2/ 33 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX" 34 | ;; 35 | cpu) 36 | unset FORCE_CUDA 37 | export CUDA_VISIBLE_DEVICES= 38 | ;; 39 | *) 40 | echo "Unrecognized CU_VERSION=$CU_VERSION" 41 | exit 1 42 | ;; 43 | esac 44 | } 45 | 46 | setup_wheel_python() { 47 | case "$PYTHON_VERSION" in 48 | 3.6) python_abi=cp36-cp36m ;; 49 | 3.7) python_abi=cp37-cp37m ;; 50 | 3.8) python_abi=cp38-cp38 ;; 51 | *) 52 | echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" 53 | exit 1 54 | ;; 55 | esac 56 | export PATH="/opt/python/$python_abi/bin:$PATH" 57 | } 58 | 
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# A shell script that parses metrics from the log file.
# Make it easier for developers to track performance of models.

LOG="$1"

if [[ -z "$LOG" ]]; then
    echo "Usage: $0 /path/to/log/file"
    exit 1
fi

# Example log lines this script expects:
# [12/15 11:47:32] trainer INFO: Total training time: 12:15:04.446477 (0.4900 s / it)
# [12/15 11:49:03] inference INFO: Total inference time: 0:01:25.326167 (0.13652186737060548 s / img per device, on 8 devices)
# [12/15 11:49:03] inference INFO: Total inference pure compute time: .....

# training time
# NOTE(review): pattern greps 'Overall training' but the sample line above says
# 'Total training time' — confirm against the actual trainer's log wording.
trainspeed=$(grep -o 'Overall training.*' "$LOG" | grep -Eo '\(.*\)' | grep -o '[0-9\.]*')
echo "Training speed: $trainspeed s/it"

# inference time: there could be multiple inference during training
# (tail -n1 keeps only the last run; head -n1 keeps the first number in the parens)
inferencespeed=$(grep -o 'Total inference pure.*' "$LOG" | tail -n1 | grep -Eo '\(.*\)' | grep -o '[0-9\.]*' | head -n1)
echo "Inference speed: $inferencespeed s/it"

# [12/15 11:47:18] trainer INFO: eta: 0:00:00 iter: 90000 loss: 0.5407 (0.7256) loss_classifier: 0.1744 (0.2446) loss_box_reg: 0.0838 (0.1160) loss_mask: 0.2159 (0.2722) loss_objectness: 0.0244 (0.0429) loss_rpn_box_reg: 0.0279 (0.0500) time: 0.4487 (0.4899) data: 0.0076 (0.0975) lr: 0.000200 max mem: 4161
# 'max[_ ]mem' matches both 'max mem' and 'max_mem' spellings
memory=$(grep -o 'max[_ ]mem: [0-9]*' "$LOG" | tail -n1 | grep -o '[0-9]*')
echo "Training memory: $memory MB"

echo "Easy to copypaste:"
echo "$trainspeed","$inferencespeed","$memory"

echo "------------------------------"

# Example evaluation block:
# [12/26 17:26:32] engine.coco_evaluation: copypaste: Task: bbox
# [12/26 17:26:32] engine.coco_evaluation: copypaste: AP,AP50,AP75,APs,APm,APl
# [12/26 17:26:32] engine.coco_evaluation: copypaste: 0.0017,0.0024,0.0017,0.0005,0.0019,0.0011
# [12/26 17:26:32] engine.coco_evaluation: copypaste: Task: segm
# [12/26 17:26:32] engine.coco_evaluation: copypaste: AP,AP50,AP75,APs,APm,APl
# [12/26 17:26:32] engine.coco_evaluation: copypaste: 0.0014,0.0021,0.0016,0.0005,0.0016,0.0011

echo "COCO Results:"
num_tasks=$(grep -o 'copypaste:.*Task.*' "$LOG" | sort -u | wc -l)
# each task has 3 lines
grep -o 'copypaste:.*' "$LOG" | cut -d ' ' -f 2- | tail -n $((num_tasks * 3))
#!/bin/bash -e
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# Smoke-test training: run each quick-schedule "instant test" config for a few
# iterations on 2 GPUs and delete the output. Extra args override the config list.

BIN="python tools/train_net.py"
OUTPUT="instant_test_output"
NUM_GPUS=2

# Configs passed on the command line take precedence; otherwise run every
# *instant_test.yaml under configs/quick_schedules.
CFG_LIST=( "${@:1}" )
if [ ${#CFG_LIST[@]} -eq 0 ]; then
    CFG_LIST=( ./configs/quick_schedules/*instant_test.yaml )
fi

echo "========================================================================"
echo "Configs to run:"
echo "${CFG_LIST[@]}"
echo "========================================================================"

for cfg in "${CFG_LIST[@]}"; do
    echo "========================================================================"
    echo "Running $cfg ..."
    echo "========================================================================"
    # 2 images per GPU keeps the instant test cheap
    $BIN --num-gpus $NUM_GPUS --config-file "$cfg" \
        SOLVER.IMS_PER_BATCH $(($NUM_GPUS * 2)) \
        OUTPUT_DIR "$OUTPUT"
    rm -rf "$OUTPUT"
done
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
# Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He

import os
import sys
import logging
import functools
from termcolor import colored


@functools.lru_cache()
def create_logger(output_dir, dist_rank=0, name=''):
    """Create and configure a logger for (distributed) training.

    A colored console handler is attached only on rank 0; every rank writes
    to its own ``log_rank{dist_rank}.txt`` inside ``output_dir``.  The
    ``lru_cache`` decorator makes repeated calls with the same arguments
    return the already-configured logger instead of stacking duplicate
    handlers onto it.

    Args:
        output_dir: directory that receives the per-rank log file.
        dist_rank: distributed rank of this process (0 = master).
        name: logger name passed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger``.
    """
    # create logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't bubble records up to the root logger (avoids double printing)
    logger.propagate = False

    # create formatter
    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
                colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'

    # create console handlers for master process
    if dist_rank == 0:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(
            logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
        logger.addHandler(console_handler)

    # create file handlers (mode='a' so auto-resumed runs append to the same file)
    file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(file_handler)

    return logger
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler
from timm.scheduler.step_lr import StepLRScheduler
from timm.scheduler.scheduler import Scheduler


def build_scheduler(config, optimizer, n_iter_per_epoch, num_steps=None, warmup_steps=None, min_lr=None):
    """Build the LR scheduler named by ``config.TRAIN.LR_SCHEDULER.NAME``.

    Supported names: ``cosine`` (timm CosineLRScheduler), ``linear``
    (LinearLRScheduler below) and ``step`` (timm StepLRScheduler).  All are
    configured per-iteration (``t_in_epochs=False``), with totals derived
    from epochs * ``n_iter_per_epoch`` unless overridden explicitly.

    Args:
        config: yacs config node (TRAIN.* fields).
        optimizer: optimizer whose learning rate is scheduled.
        n_iter_per_epoch: iterations per epoch, used to convert epoch counts
            to step counts.
        num_steps / warmup_steps / min_lr: optional overrides for the values
            otherwise computed from the config.

    Returns:
        The scheduler instance, or None if the name is not recognized.
    """
    if num_steps is None:
        num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
    if warmup_steps is None:
        warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
    if min_lr is None:
        min_lr = config.TRAIN.MIN_LR
    decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)

    # NOTE: stays None for an unrecognized scheduler name — callers will fail
    # later on .step_update(); unlike build_model this does not raise here.
    lr_scheduler = None
    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_steps,
            t_mul=1.,  # single cycle, no period multiplication (legacy timm arg)
            lr_min=min_lr,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            cycle_limit=1,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
        lr_scheduler = LinearLRScheduler(
            optimizer,
            t_initial=num_steps,
            lr_min_rate=0.01,  # decay to 1% of the base lr by the final step
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )
    elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=decay_steps,
            decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
            warmup_lr_init=config.TRAIN.WARMUP_LR,
            warmup_t=warmup_steps,
            t_in_epochs=False,
        )

    return lr_scheduler


class LinearLRScheduler(Scheduler):
    """Linear warmup followed by linear decay, on timm's Scheduler base.

    After ``warmup_t`` steps of linear warmup from ``warmup_lr_init`` to each
    group's base lr, the lr decays linearly so that at ``t_initial`` it
    reaches ``base_lr * lr_min_rate``.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 lr_min_rate: float,
                 warmup_t=0,
                 warmup_lr_init=0.,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True,
                 ) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        self.t_initial = t_initial
        self.lr_min_rate = lr_min_rate
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # per-group increment so that warmup reaches base lr at t == warmup_t
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            # linear warmup phase
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            # linear decay from base lr towards base lr * lr_min_rate
            t = t - self.warmup_t
            total_t = self.t_initial - self.warmup_t
            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
        return lrs

    def get_epoch_values(self, epoch: int):
        # only active when the schedule is expressed in epochs
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        # only active when the schedule is expressed in iterations
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None
def parse_option():
    """Parse CLI arguments and build the run config.

    Uses ``parse_known_args`` so that unknown flags are tolerated (they are
    ignored, not an error).  ``--opts KEY VALUE ...`` pairs are applied last
    inside ``get_config`` and override the YAML config file.

    Returns:
        (args, config): the parsed ``argparse.Namespace`` and the frozen
        yacs config node produced by ``get_config(args)``.
    """
    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is // (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    # parse_known_args: ignore (don't fail on) arguments this script doesn't know
    args, unparsed = parser.parse_known_args()

    config = get_config(args)

    return args, config
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
    """Run one training epoch.

    Two update paths, chosen by ``config.TRAIN.ACCUMULATION_STEPS``:
    with accumulation > 1 the loss is divided by the number of accumulation
    steps and the optimizer steps (and the per-iteration lr scheduler
    updates) only every ACCUMULATION_STEPS batches; otherwise the optimizer
    steps every batch.  Both paths support apex AMP loss scaling (when
    ``config.AMP_OPT_LEVEL`` != "O0") and optional gradient clipping.

    Side effects: updates model parameters, steps ``lr_scheduler``
    per-iteration, and logs progress every ``config.PRINT_FREQ`` batches.
    """
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        # mixup/cutmix transforms both images and (soft) targets
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            # gradient-accumulation path: scale the loss so the summed
            # gradients match a single large-batch step
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            # only step/zero/schedule on the last micro-batch of each group
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            # standard path: one optimizer step per batch
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        # make batch timing meaningful (CUDA kernels are async)
        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
@torch.no_grad()
def throughput(data_loader, model, logger):
    """Benchmark forward-pass throughput (images/second).

    Uses only the FIRST batch of ``data_loader`` (the function returns from
    inside the loop): 50 untimed warmup forward passes, then 30 timed passes
    bracketed by ``torch.cuda.synchronize()`` so the wall-clock time covers
    all queued GPU work.  Results are logged, nothing is returned.
    """
    model.eval()

    for idx, (images, _) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]
        # warmup: let cudnn pick algorithms, fill caches
        for i in range(50):
            model(images)
        torch.cuda.synchronize()
        logger.info(f"throughput averaged with 30 times")
        tic1 = time.time()
        for i in range(30):
            model(images)
        torch.cuda.synchronize()
        tic2 = time.time()
        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
        # deliberately measure only the first batch
        return
if __name__ == '__main__':
    _, config = parse_option()

    if config.AMP_OPT_LEVEL != "O0":
        assert amp is not None, "amp not installed!"

    # torch.distributed.launch / torchrun export RANK and WORLD_SIZE;
    # fall back to -1 so init_process_group reads them from the env itself
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    # distinct seed per rank so data augmentation differs across processes
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also need to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    # only rank 0 writes the resolved config to disk
    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())

    main(config)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to corner (x0, y0, x1, y1) format.

    Works on any tensor whose last dimension holds the 4 box coordinates.
    """
    cx, cy, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)


def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner (x0, y0, x1, y1) to (cx, cy, w, h) format.

    Inverse of :func:`box_cxcywh_to_xyxy`; operates on the last dimension.
    """
    x0, y0, x1, y1 = x.unbind(-1)
    center_size = [(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0]
    return torch.stack(center_size, dim=-1)
def build_optimizer(config, model):
    """
    Build optimizer, set weight decay of normalization to 0 by default.

    Parameters are grouped by ``set_weight_decay``: 1-D tensors, biases and
    anything listed in the model's ``no_weight_decay``/
    ``no_weight_decay_keywords`` hooks get weight_decay=0.  Supported
    optimizer names (case-insensitive): 'sgd', 'adamw'.

    Returns the optimizer, or None if the name is not recognized.
    """
    skip = {}
    skip_keywords = {}
    # models may opt specific parameters out of weight decay via these hooks
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    optimizer = None
    if opt_lower == 'sgd':
        optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)

    return optimizer
39 | """ 40 | skip = {} 41 | skip_keywords = {} 42 | if hasattr(model, 'no_weight_decay'): 43 | skip = model.no_weight_decay() 44 | if hasattr(model, 'no_weight_decay_keywords'): 45 | skip_keywords = model.no_weight_decay_keywords() 46 | parameters1, parameters2 = set_weight_decay_and_lr_2parameters(model, skip, skip_keywords, config.TRAIN.BASE_LR, config.EXTRA.architecture_lr) 47 | 48 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 49 | optimizer1 = None 50 | optimizer2 = None 51 | 52 | if opt_lower == 'sgd': 53 | optimizer1 = optim.SGD(parameters1, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 54 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 55 | elif opt_lower == 'adamw': 56 | optimizer1 = optim.AdamW(parameters1, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 57 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 58 | 59 | optimizer2 = optim.AdamW(parameters2, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 60 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 61 | 62 | return optimizer1, optimizer2 63 | 64 | 65 | def build_our_optimizer(config, model): 66 | """ 67 | Build optimizer, set weight decay of normalization to 0 by default. 
68 | """ 69 | skip = {} 70 | skip_keywords = {} 71 | if hasattr(model, 'no_weight_decay'): 72 | skip = model.no_weight_decay() 73 | if hasattr(model, 'no_weight_decay_keywords'): 74 | skip_keywords = model.no_weight_decay_keywords() 75 | parameters = set_weight_decay_and_lr(model, skip, skip_keywords, config.TRAIN.BASE_LR, config.EXTRA.architecture_lr) 76 | 77 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 78 | 79 | if opt_lower == 'sgd': 80 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 81 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 82 | elif opt_lower == 'adamw': 83 | 84 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 85 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 86 | 87 | return optimizer 88 | 89 | 90 | def set_weight_decay(model, skip_list=(), skip_keywords=(), small_decay_num=0.0): 91 | has_decay = [] 92 | no_decay = [] 93 | small_decay = [] 94 | 95 | for name, param in model.named_parameters(): 96 | if not param.requires_grad: 97 | continue # frozen weights 98 | if (len(param.shape) == 1 and 'thresholds' not in name) or name.endswith(".bias") or (name in skip_list) or \ 99 | check_keywords_in_name(name, skip_keywords): 100 | no_decay.append(param) 101 | # print(f"{name} has no weight decay") 102 | elif 'thresholds' in name: 103 | small_decay.append(param) 104 | else: 105 | has_decay.append(param) 106 | 107 | return [{'params': has_decay}, 108 | {'params': no_decay, 'weight_decay': 0.}, 109 | {'params': small_decay, 'weight_decay': small_decay_num}] 110 | 111 | 112 | def set_weight_decay_and_lr_2parameters(model, skip_list=(), skip_keywords=(), regular_lr=None, diff_lr=None): 113 | 114 | has_decay = [] 115 | no_decay = [] 116 | has_diff_lr = [] 117 | 118 | for name, param in model.named_parameters(): 119 | if not param.requires_grad: 120 | continue # frozen weights 121 | if 'thresholds' in name: 122 | 
has_diff_lr.append(param) 123 | else: 124 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 125 | check_keywords_in_name(name, skip_keywords): 126 | no_decay.append(param) 127 | # print(f"{name} has no weight decay") 128 | else: 129 | has_decay.append(param) 130 | 131 | return [{'params': has_decay}, 132 | {'params': no_decay, 'weight_decay': 0.}], \ 133 | [{'params': has_diff_lr, 'lr': diff_lr, 'weight_decay': 0.}] 134 | 135 | 136 | def set_weight_decay_and_lr(model, skip_list=(), skip_keywords=(), regular_lr=None, diff_lr=None): 137 | 138 | has_decay = [] 139 | no_decay = [] 140 | has_diff_lr = [] 141 | 142 | for name, param in model.named_parameters(): 143 | if not param.requires_grad: 144 | continue # frozen weights 145 | if 'thresholds' in name: 146 | has_diff_lr.append(param) 147 | else: 148 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 149 | check_keywords_in_name(name, skip_keywords): 150 | no_decay.append(param) 151 | # print(f"{name} has no weight decay") 152 | else: 153 | has_decay.append(param) 154 | 155 | return [{'params': has_decay}, 156 | {'params': no_decay, 'weight_decay': 0.}, 157 | {'params': has_diff_lr, 'lr': diff_lr, 'weight_decay': 0.}] 158 | 159 | 160 | def check_keywords_in_name(name, keywords=()): 161 | isin = False 162 | for keyword in keywords: 163 | if keyword in name: 164 | isin = True 165 | return isin 166 | -------------------------------------------------------------------------------- /SPViT_Swin/post_training_optimize_checkpoint.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from collections import OrderedDict 4 | import ast 5 | 6 | 7 | def main(): 8 | 9 | # Optimize the number of parameters of a checkpoint 10 | 11 | if len(sys.argv) != 3: 12 | print('Error: Two input arguments, checkpoint_path and searched MSA architecture in a list.') 13 | return 14 | 15 | checkpoint_path = sys.argv[1] 16 | 
state_dict = torch.load(checkpoint_path, map_location='cpu')['model'] 17 | 18 | try: 19 | # Use ast.literal_eval to safely parse the string as a list 20 | MSA_indicators = ast.literal_eval(sys.argv[2]) 21 | 22 | if not isinstance(MSA_indicators, list): 23 | raise ValueError("The provided parameter is not a valid list.") 24 | 25 | # Now you have the list parameter 26 | print(f"List Parameter: {MSA_indicators}") 27 | 28 | except (ValueError, SyntaxError) as e: 29 | print(f"Invalid MSA indicators: {e}") 30 | 31 | new_dict = OrderedDict() 32 | 33 | if any('bconv' in key for key in list(state_dict.keys())): 34 | print('Error: The checkpoint is already optimized!') 35 | return 36 | 37 | for k, v in state_dict.items(): 38 | if 'head_probs' in k: 39 | block_name = k.replace('head_probs', '') 40 | block_num = int(k.split('.')[1]) 41 | head_probs = (state_dict[k] / 1e-2).softmax(0) 42 | num_heads = head_probs.shape[0] 43 | feature_dim = state_dict[block_name + 'v.weight'].shape[0] 44 | head_dim = feature_dim // num_heads 45 | 46 | if MSA_indicators[block_num][-1] == 1: 47 | print('Error: checkpoint and MSA indicators do not match!') 48 | return 49 | 50 | new_v_weight = state_dict[block_name + 'v.weight'].view(num_heads, head_dim, feature_dim).permute(1, 2, 0) @ head_probs 51 | new_v_bias = state_dict[block_name + 'v.bias'].view(num_heads, head_dim).permute(1, 0) @ head_probs 52 | new_proj_weight = state_dict[block_name + 'proj.weight'].view(feature_dim, num_heads, head_dim).permute(0, 2, 1) @ head_probs 53 | 54 | if MSA_indicators[block_num][1] == 1: 55 | bn_name = 'bn_3x3.' 56 | new_dict[block_name + 'bconv.0.weight'] = new_v_weight.permute(2, 0, 1).view(3, 3, head_dim, -1).permute(2, 3, 1, 0) 57 | new_dict[block_name + 'bconv.0.bias'] = new_v_bias.sum(-1) 58 | new_dict[block_name + 'bconv.3.weight'] = new_proj_weight.sum(-1)[..., None, None] 59 | else: 60 | bn_name = 'bn_1x1.' 
61 | new_dict[block_name + 'bconv.0.weight'] = new_v_weight[..., 4][..., None, None] 62 | new_dict[block_name + 'bconv.0.bias'] = new_v_bias[..., 4] 63 | new_dict[block_name + 'bconv.3.weight'] = new_proj_weight[..., 4][..., None, None] 64 | 65 | new_dict[block_name + 'bconv.3.bias'] = state_dict[block_name + 'proj.bias'] 66 | 67 | new_dict[block_name + 'bconv.1.weight'] = state_dict[block_name + bn_name + 'weight'] 68 | new_dict[block_name + 'bconv.1.bias'] = state_dict[block_name + bn_name + 'bias'] 69 | new_dict[block_name + 'bconv.1.running_mean'] = state_dict[block_name + bn_name + 'running_mean'] 70 | new_dict[block_name + 'bconv.1.running_var'] = state_dict[block_name + bn_name + 'running_var'] 71 | new_dict[block_name + 'bconv.1.num_batches_tracked'] = state_dict[block_name + bn_name + 'num_batches_tracked'] 72 | 73 | else: 74 | if len(k.split('.')) <= 4 or '.'.join(k.split('.')[:-2]) + '.head_probs' not in state_dict.keys(): 75 | new_dict[k] = state_dict[k] 76 | 77 | print(f'Save to ' + '.'.join(checkpoint_path.split('.')[:-1]) + '_optimized.pth') 78 | torch.save({'model': new_dict}, '.'.join(checkpoint_path.split('.')[:-1]) + '_optimized.pth') 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /SPViT_Swin/requirements.txt: -------------------------------------------------------------------------------- 1 | fvcore -------------------------------------------------------------------------------- /SPViT_Swin/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
# All Rights Reserved
# Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He

import glob
import os
import shutil
from os import path
from setuptools import find_packages, setup
from typing import List
import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
from torch.utils.hipify import hipify_python

torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
assert torch_ver >= [1, 4], "Requires PyTorch >= 1.4"


def get_version():
    """
    Read ``__version__`` from detectron2/__init__.py, append any release or
    nightly suffix, and (for nightly builds) rewrite the file with the final
    version string.  Returns the version string.
    """
    init_py_path = path.join(path.abspath(path.dirname(__file__)), "detectron2", "__init__.py")
    # Bug fix: use a context manager so the file handle is closed
    # deterministically (the original leaked it).
    with open(init_py_path, "r") as f:
        init_py = f.readlines()
    version_line = [l.strip() for l in init_py if l.startswith("__version__")][0]
    version = version_line.split("=")[-1].strip().strip("'\"")

    # The following is used to build release packages.
    # Users should never use it.
    suffix = os.getenv("D2_VERSION_SUFFIX", "")
    version = version + suffix
    if os.getenv("BUILD_NIGHTLY", "0") == "1":
        from datetime import datetime

        date_str = datetime.today().strftime("%y%m%d")
        version = version + ".dev" + date_str

        new_init_py = [l for l in init_py if not l.startswith("__version__")]
        new_init_py.append('__version__ = "{}"\n'.format(version))
        with open(init_py_path, "w") as f:
            f.write("".join(new_init_py))
    return version


def get_extensions():
    """
    Collect the C++/CUDA (or HIP, on ROCm PyTorch) sources under
    detectron2/layers/csrc and return the extension module list for setup().
    """
    this_dir = path.dirname(path.abspath(__file__))
    extensions_dir = path.join(this_dir, "detectron2", "layers", "csrc")

    main_source = path.join(extensions_dir, "vision.cpp")
    sources = glob.glob(path.join(extensions_dir, "**", "*.cpp"))

    is_rocm_pytorch = False
    if torch_ver >= [1, 5]:
        # ROCM_HOME only exists in the cpp_extension module from PyTorch 1.5.
        from torch.utils.cpp_extension import ROCM_HOME

        is_rocm_pytorch = (torch.version.hip is not None) and (ROCM_HOME is not None)

    if is_rocm_pytorch:
        hipify_python.hipify(
            project_directory=this_dir,
            output_directory=this_dir,
            includes="/detectron2/layers/csrc/*",
            show_detailed=True,
            is_pytorch_extension=True,
        )

        # Current version of hipify function in pytorch creates an intermediate directory
        # named "hip" at the same level of the path hierarchy if a "cuda" directory exists,
        # or modifying the hierarchy, if it doesn't. Once pytorch supports
        # "same directory" hipification (PR pending), the source_cuda will be set
        # similarly in both cuda and hip paths, and the explicit header file copy
        # (below) will not be needed.
        source_cuda = glob.glob(path.join(extensions_dir, "**", "hip", "*.hip")) + glob.glob(
            path.join(extensions_dir, "hip", "*.hip")
        )

        shutil.copy(
            "detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h",
            "detectron2/layers/csrc/box_iou_rotated/hip/box_iou_rotated_utils.h",
        )
        shutil.copy(
            "detectron2/layers/csrc/deformable/deform_conv.h",
            "detectron2/layers/csrc/deformable/hip/deform_conv.h",
        )

    else:
        source_cuda = glob.glob(path.join(extensions_dir, "**", "*.cu")) + glob.glob(
            path.join(extensions_dir, "*.cu")
        )

    sources = [main_source] + sources

    extension = CppExtension

    extra_compile_args = {"cxx": []}
    define_macros = []

    # Build the CUDA/HIP extension when a GPU toolchain is available, or when
    # forced via FORCE_CUDA=1 (useful for building on CPU-only CI machines).
    if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv(
        "FORCE_CUDA", "0"
    ) == "1":
        extension = CUDAExtension
        sources += source_cuda

        if not is_rocm_pytorch:
            define_macros += [("WITH_CUDA", None)]
            extra_compile_args["nvcc"] = [
                "-DCUDA_HAS_FP16=1",
                "-D__CUDA_NO_HALF_OPERATORS__",
                "-D__CUDA_NO_HALF_CONVERSIONS__",
                "-D__CUDA_NO_HALF2_OPERATORS__",
                "-arch=sm_60",
                "-gencode=arch=compute_60,code=sm_60",
                "-gencode=arch=compute_61,code=sm_61",
                "-gencode=arch=compute_70,code=sm_70",
                "-gencode=arch=compute_75,code=sm_75",
            ]
        else:
            define_macros += [("WITH_HIP", None)]
            extra_compile_args["nvcc"] = []

        # It's better if pytorch can do this by default ..
        CC = os.environ.get("CC", None)
        if CC is not None:
            extra_compile_args["nvcc"].append("-ccbin={}".format(CC))

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "detectron2._C",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules


def get_model_zoo_configs() -> List[str]:
    """
    Return a list of configs to include in package for model zoo. Copy over these configs inside
    detectron2/model_zoo.
    """

    # Use absolute paths while symlinking.
    source_configs_dir = path.join(path.dirname(path.realpath(__file__)), "configs")
    destination = path.join(
        path.dirname(path.realpath(__file__)), "detectron2", "model_zoo", "configs"
    )
    # Symlink the config directory inside package to have a cleaner pip install.

    # Remove stale symlink/directory from a previous build.
    if path.exists(source_configs_dir):
        if path.islink(destination):
            os.unlink(destination)
        elif path.isdir(destination):
            shutil.rmtree(destination)

    if not path.exists(destination):
        try:
            os.symlink(source_configs_dir, destination)
        except OSError:
            # Fall back to copying if symlink fails: ex. on Windows.
            shutil.copytree(source_configs_dir, destination)

    config_paths = glob.glob("configs/**/*.yaml", recursive=True)
    return config_paths


setup(
    name="detectron2",
    version=get_version(),
    author="FAIR",
    url="https://github.com/facebookresearch/detectron2",
    description="Detectron2 is FAIR's next-generation research "
    "platform for object detection and segmentation.",
    packages=find_packages(exclude=("configs", "tests*")),
    package_data={"detectron2.model_zoo": get_model_zoo_configs()},
    python_requires=">=3.6",
    install_requires=[
        "termcolor>=1.1",
        "Pillow>=7.0",  # or use pillow-simd for better performance
        "yacs>=0.1.6",
        "tabulate",
        "cloudpickle",
        "matplotlib",
        "mock",
        "tqdm>4.29.0",
        "tensorboard",
        "fvcore>=0.1.1",
        "pycocotools>=2.0.1",
        "future",  # used by caffe2
        "pydot",  # used to save caffe2 SVGs
    ],
    extras_require={
        "all": ["shapely", "psutil"],
        "dev": [
            "flake8==3.8.1",
            "isort",
            "black @ git+https://github.com/psf/black@673327449f86fce558adde153bb6cbe54bfebad2",
            "flake8-bugbear",
            "flake8-comprehensions",
        ],
    },
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)