├── .gitignore
├── LICENSE
├── README.md
├── SPViT_DeiT
│   ├── .gitignore
│   ├── README.md
│   ├── config
│   │   ├── spvit_deit_bs_l006_t100_ft.json
│   │   ├── spvit_deit_bs_l006_t100_search.json
│   │   ├── spvit_deit_bs_l008_t60_ft.json
│   │   ├── spvit_deit_bs_l008_t60_ft_dist.json
│   │   ├── spvit_deit_bs_l008_t60_ft_param_opt.json
│   │   ├── spvit_deit_bs_l008_t60_search.json
│   │   ├── spvit_deit_sm_l30_t32_ft.json
│   │   ├── spvit_deit_sm_l30_t32_ft_dist.json
│   │   ├── spvit_deit_sm_l30_t32_search.json
│   │   ├── spvit_deit_ti_l200_t10_ft.json
│   │   ├── spvit_deit_ti_l200_t10_ft_dist.json
│   │   └── spvit_deit_ti_l200_t10_search.json
│   ├── datasets.py
│   ├── engine.py
│   ├── ffn_indicators
│   │   ├── .DS_Store
│   │   ├── spvit_deit_bs_l006_t100_search_15epoch.pth
│   │   ├── spvit_deit_bs_l008_t60_search_10epoch.pth
│   │   ├── spvit_deit_sm_l30_t32_search_10epoch.pth
│   │   └── spvit_deit_ti_l200_t10_search_10epoch.pth
│   ├── hubconf.py
│   ├── logger.py
│   ├── losses.py
│   ├── main.py
│   ├── main_pruning.py
│   ├── models.py
│   ├── models_pruning.py
│   ├── params.py
│   ├── post_training_optimize_checkpoint.py
│   ├── requirements.txt
│   ├── samplers.py
│   ├── tox.ini
│   └── utils.py
└── SPViT_Swin
    ├── .gitignore
    ├── README.md
    ├── config.py
    ├── configs
    │   ├── spvit_swin_bs_l01_t100_ft.yaml
    │   ├── spvit_swin_bs_l01_t100_search.yaml
    │   ├── spvit_swin_sm_l04_t55_ft.yaml
    │   ├── spvit_swin_sm_l04_t55_ft_dist.yaml
    │   ├── spvit_swin_sm_l04_t55_search.yaml
    │   ├── spvit_swin_tn_l28_t32_ft.yaml
    │   ├── spvit_swin_tn_l28_t32_ft_dist.yaml
    │   └── spvit_swin_tn_l28_t32_search.yaml
    ├── data
    │   ├── __init__.py
    │   ├── build.py
    │   ├── cached_image_folder.py
    │   ├── samplers.py
    │   └── zipreader.py
    ├── dev
    │   ├── README.md
    │   ├── linter.sh
    │   ├── packaging
    │   │   ├── README.md
    │   │   ├── build_all_wheels.sh
    │   │   ├── build_wheel.sh
    │   │   ├── gen_install_table.py
    │   │   ├── gen_wheel_index.sh
    │   │   └── pkg_helpers.bash
    │   ├── parse_results.sh
    │   ├── run_inference_tests.sh
    │   └── run_instant_tests.sh
    ├── ffn_indicators
    │   ├── .DS_Store
    │   ├── spvit_swin_bs_l01_t100_search_20epoch.pth
    │   ├── spvit_swin_sm_l04_t55_search_14epoch.pth
    │   └── spvit_swin_t_l28_t32_search_12epoch.pth
    ├── logger.py
    ├── lr_scheduler.py
    ├── main.py
    ├── main_pruning.py
    ├── models
    │   ├── __init__.py
    │   ├── build.py
    │   ├── spvit_swin.py
    │   └── utils.py
    ├── optimizer.py
    ├── post_training_optimize_checkpoint.py
    ├── requirements.txt
    ├── setup.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [TPAMI 2024] Pruning Self-attentions into Convolutional Layers in Single Path
2 |
3 | **This is the official repository for our paper:** [Pruning Self-attentions into Convolutional Layers in Single Path](https://arxiv.org/abs/2111.11802) by [Haoyu He](https://charles-haoyuhe.github.io/), [Jianfei Cai](https://jianfei-cai.github.io/), [Jing liu](https://sites.google.com/view/jing-liu/%E9%A6%96%E9%A1%B5), [Zizheng Pan](https://zizhengpan.github.io/), [Jing Zhang](https://scholar.google.com/citations?user=9jH5v74AAAAJ&hl=en), [Dacheng Tao](https://www.sydney.edu.au/engineering/about/our-people/academic-staff/dacheng-tao.html) and [Bohan Zhuang](https://bohanzhuang.github.io/).
4 |
5 | ***
6 |
7 | >🚀 News
8 | >
9 | >[2023-12-29]: Accepted by TPAMI!
10 | >
11 | >[2023-06-09]: Updated distillation configurations and pre-trained checkpoints.
12 | >
13 | >[2021-12-04]: Released pre-trained models.
14 | >
15 | >[2021-11-25]: Released code.
16 |
17 | ***
18 |
19 | ### Introduction:
20 |
21 | To reduce the massive computational cost of ViTs and add convolutional inductive bias, **our SPViT prunes pre-trained ViT models into accurate and compact hybrid models by pruning self-attentions into convolutional layers**. Thanks to the proposed weight-sharing scheme between self-attention and convolutional layers, which casts the search problem as finding which subset of parameters to use, **SPViT significantly reduces the search cost**.
22 |
23 | ***
24 |
25 | ### Experimental results:
26 |
27 | We provide experimental results and pre-trained models for SPViT:
28 |
29 | | Name | Acc@1 | Acc@5 | # parameters | FLOPs | Model |
30 | | :------------ | :---: | :---: | ------------ | ----- | ------------------------------------------------------------ |
31 | | SPViT-DeiT-Ti | 70.7 | 90.3 | 4.9M | 1.0G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_ti_l200_t10.pth) |
32 | | SPViT-DeiT-Ti* | 73.2 | 91.4 | 4.9M | 1.0G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_ti_l200_t10_dist.pth) |
33 | | SPViT-DeiT-S | 78.3 | 94.3 | 16.4M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_sm_l30_t32.pth) |
34 | | SPViT-DeiT-S* | 80.3 | 95.1 | 16.4M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_sm_l30_t32_dist.pth) |
35 | | SPViT-DeiT-B | 81.5 | 95.7 | 46.2M | 8.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_bs_l008_t60.pth) |
36 | | SPViT-DeiT-B* | 82.4 | 96.1 | 46.2M | 8.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_deit_bs_l008_t60_dist.pth) |
37 |
38 | | Name | Acc@1 | Acc@5 | # parameters | FLOPs | Model |
39 | | :------------ | :---: | :---: | ------------ | ----- | ------------------------------------------------------------ |
40 | | SPViT-Swin-Ti | 80.1 | 94.9 | 26.3M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_t_l28_t32.pth) |
41 | | SPViT-Swin-Ti* | 81.0 | 95.3 | 26.3M | 3.3G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_t_l28_t32_dist.pth) |
42 | | SPViT-Swin-S | 82.4 | 96.0 | 39.2M | 6.1G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_sm_l04_t55.pth) |
43 | | SPViT-Swin-S* | 83.0 | 96.4 | 39.2M | 6.1G | [Model](https://github.com/ziplab/SPViT/releases/download/1.0/spvit_swin_sm_l04_t55_dist.pth) |
44 |
45 | \* indicates knowledge distillation.
46 | ### Getting started:
47 |
48 | In this repository, we provide code for pruning two representative ViT models.
49 |
50 | - SPViT-DeiT that prunes [DeiT](https://github.com/facebookresearch/deit). Please see [SPViT_DeiT/README.md](SPViT_DeiT/README.md ) for details.
51 | - SPViT-Swin that prunes [Swin](https://github.com/microsoft/Swin-Transformer). Please see [SPViT_Swin/README.md](SPViT_Swin/README.md) for details.
52 |
53 | ***
54 |
55 | If you find our paper useful, please consider citing:
56 |
57 | ```
58 | @article{he2024Pruning,
59 | title={Pruning Self-attentions into Convolutional Layers in Single Path},
60 | author={He, Haoyu and Liu, Jing and Pan, Zizheng and Cai, Jianfei and Zhang, Jing and Tao, Dacheng and Zhuang, Bohan},
61 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
62 | year={2024},
63 | publisher={IEEE}
64 | }
65 |
66 | ```
67 |
68 |
--------------------------------------------------------------------------------
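To make the weight-sharing scheme above concrete, here is a minimal, self-contained sketch. It is illustrative only, not the repository's implementation (the real operators are the `UnifiedAttention`/`UnifiedMlp` layers referenced in the configs below); the soft gate and the choice of reused projections are assumptions. The point it shows: the convolutional path reuses a subset of the MSA parameters, so selecting an operation amounts to selecting which shared parameters to use.

```python
# Illustrative sketch of single-path weight sharing (not the repo's code):
# the conv-like path reuses a subset of the MSA parameters (value + output
# projections), so architecture search = choosing which parameters to use.
import torch
import torch.nn as nn


class SinglePathAttention(nn.Module):
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)             # shared weights
        self.proj = nn.Linear(dim, dim)                # shared weights
        self.indicator = nn.Parameter(torch.zeros(1))  # learned gate

    def forward(self, x):  # x: (B, N, C)
        B, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (B, heads, N, head_dim)

        # Path 1: full self-attention.
        attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
        msa = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)

        # Path 2: the value projection alone acts like a 1x1 convolution over
        # tokens -- no new parameters, just a subset of the MSA ones.
        conv_like = v.transpose(1, 2).reshape(B, N, C)

        gate = torch.sigmoid(self.indicator)  # soft single-path selection
        return self.proj(gate * msa + (1 - gate) * conv_like)


y = SinglePathAttention(64)(torch.randn(2, 196, 64))  # -> (2, 196, 64)
```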
/SPViT_DeiT/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | **/__pycache__/**
3 | imnet_resnet50_scratch/timm_temp/
4 | .dumbo.json
5 | checkpoints/
6 |
--------------------------------------------------------------------------------
/SPViT_DeiT/README.md:
--------------------------------------------------------------------------------
1 | ### Getting started on SPViT-DeiT:
2 |
3 | #### Installation and data preparation
4 |
5 | - First, you can install the required environment as described in the [DeiT](https://github.com/facebookresearch/deit) repository, or follow the instructions below:
6 |
7 | ```bash
8 | # Create virtual env
9 | conda create -n spvit-deit python=3.7 -y
10 | conda activate spvit-deit
11 |
12 | # Install PyTorch 1.7.0+ and torchvision 0.8.1+ and pytorch-image-models 0.3.2:
13 | conda install -c pytorch pytorch torchvision
14 | pip install timm==0.3.2
15 | ```
16 |
17 | - Next, install some other dependencies that are required by SPViT:
18 |
19 | ```bash
20 | pip install tensorboardX tensorboard
21 | ```
22 |
23 | - Please refer to the [DeiT](https://github.com/facebookresearch/deit) repository to prepare the standard ImageNet dataset, then link it under the `data` folder:
24 |
25 | ```bash
26 | $ tree data
27 | imagenet
28 | ├── train
29 | │ ├── class1
30 | │ │ ├── img1.jpeg
31 | │ │ ├── img2.jpeg
32 | │ │ └── ...
33 | │ ├── class2
34 | │ │ ├── img3.jpeg
35 | │ │ └── ...
36 | │ └── ...
37 | └── val
38 | ├── class1
39 | │ ├── img4.jpeg
40 | │ ├── img5.jpeg
41 | │ └── ...
42 | ├── class2
43 | │ ├── img6.jpeg
44 | │ └── ...
45 | └── ...
46 | ```
47 |
48 | #### Download pretrained models
49 |
50 | - Both searching and fine-tuning start from the pre-trained models.
51 |
52 | - We provide training scripts for three DeiT models: DeiT-Ti, DeiT-S and DeiT-B. Please download the corresponding pre-trained models from the [DeiT](https://github.com/facebookresearch/deit) repository.
53 |
54 | - Next, move the downloaded pre-trained models into the following file structure:
55 |
56 | ```bash
57 | $ tree model
58 | ├── deit_base_patch16_224-b5f2ef4d.pth
59 | ├── deit_small_patch16_224-cd65a155.pth
60 | ├── deit_tiny_patch16_224-a1311bcf.pth
61 | ```
62 |
63 | - Do not rename the downloaded pre-trained models: these filenames are hard-coded when tailoring and loading the checkpoints. Feel free to modify the hard-coded parts when pruning from other pre-trained models.
64 |
65 | #### Searching
66 |
67 | To search architectures with SPViT-DeiT-Ti, run:
68 |
69 | ```bash
70 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_search.json
71 | ```
72 |
73 | To search architectures with SPViT-DeiT-S, run:
74 |
75 | ```bash
76 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_search.json
77 | ```
78 |
79 | To search architectures with SPViT-DeiT-B, run:
80 |
81 | ```bash
82 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_search.json
83 | ```
84 |
85 | #### Fine-tuning
86 |
87 | You can fine-tune either your own searched architectures or our provided ones by assigning the MSA indicators via `assigned_indicators` and the FFN indicators via `searching_model` in the config files.
88 |
89 | To fine-tune the architectures searched by SPViT-DeiT-Ti, run:
90 |
91 | ```bash
92 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_ft.json
93 | ```
94 |
95 | To fine-tune the architectures with SPViT-DeiT-S, run:
96 |
97 | ```bash
98 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_ft.json
99 | ```
100 |
101 | To fine-tune the architectures with SPViT-DeiT-B, run:
102 |
103 | ```bash
104 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_ft.json
105 | ```
106 |
107 | #### Evaluation
108 |
109 | We provide several examples for evaluating pre-trained SPViT models.
110 |
111 | To evaluate SPViT-DeiT-Ti pre-trained models, run:
112 |
113 | ```bash
114 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_ft.json --resume [PRE-TRAINED MODEL PATH] --eval
115 | ```
116 |
117 | To evaluate SPViT-DeiT-S pre-trained models, run:
118 |
119 | ```bash
120 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_ft.json --resume [PRE-TRAINED MODEL PATH] --eval
121 | ```
122 |
123 | To evaluate SPViT-DeiT-B pre-trained models, run:
124 |
125 | ```bash
126 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_ft.json --resume [PRE-TRAINED MODEL PATH] --eval
127 | ```
128 |
129 | After fine-tuning, you can optimize your checkpoint to a smaller size with the following code:
130 | ```bash
131 | python post_training_optimize_checkpoint.py YOUR_CHECKPOINT_PATH
132 | ```
133 | The optimized checkpoint can be evaluated by replacing `UnifiedAttention` with `UnifiedAttentionParamOpt`; we provide an example config in `SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft_param_opt.json`.
134 |
135 | #### TODO:
136 |
137 | - [x] Release code.
138 | - [x] Release pre-trained models.
139 |
140 |
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_bs_l006_t100_ft.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_base_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_bs_l006_t100_ft",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 0.06,
13 | "theta": 1.5,
14 | "target_flops": 10.0,
15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_bs_l006_t100_search_15epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "epochs": 130,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_bs_l006_t100_search.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_base_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_bs_l006_t100_search",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 0.06,
13 | "theta": 1.5,
14 | "target_flops": 10.0,
15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth",
16 | "searching_model": "",
17 | "assigned_indicators": [],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "min_lr": 1e-4,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_base_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_bs_l008_t60_ft",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 0.08,
13 | "theta": 1.5,
14 | "target_flops": 6.0,
15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "epochs": 130,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft_dist.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_base_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_bs_l008_t60_ft_dist",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 0.08,
13 | "theta": 1.5,
14 | "target_flops": 6.0,
15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "epochs": 200,
21 | "warmup_epochs": 0,
22 | "teacher_model": "regnety_160",
23 | "teacher_path": "https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth",
24 | "distillation_type": "hard"
25 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft_param_opt.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_base_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_bs_l008_t60_ft",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttentionParamOpt",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 0.08,
13 | "theta": 1.5,
14 | "target_flops": 6.0,
15 | "resume": "model/spvit_deit_bs_l008_t60_dist_optimized.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "epochs": 130,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_bs_l008_t60_search.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_base_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_bs_l006_t100_search",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 0.08,
13 | "theta": 1.5,
14 | "target_flops": 6.0,
15 | "resume": "model/deit_base_patch16_224-b5f2ef4d.pth",
16 | "searching_model": "",
17 | "assigned_indicators": [],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "min_lr": 1e-4,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_sm_l30_t32_ft.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_small_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_sm_l30_t32_ft",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 3.0,
13 | "theta": 1.5,
14 | "target_flops": 3.2,
15 | "resume": "model/deit_small_patch16_224-cd65a155.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 5e-4,
19 | "lr": 5e-5,
20 | "epochs": 130,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_sm_l30_t32_ft_dist.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_small_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_sm_l30_t32_ft_dist",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 3.0,
13 | "theta": 1.5,
14 | "target_flops": 3.2,
15 | "resume": "model/deit_small_patch16_224-cd65a155.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 5e-4,
19 | "lr": 5e-5,
20 | "epochs": 200,
21 | "warmup_epochs": 0,
22 | "teacher_model": "regnety_160",
23 | "teacher_path": "https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth",
24 | "distillation_type": "hard"
25 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_sm_l30_t32_search.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_small_patch16_224",
3 | "batch_size": 128,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_sm_l30_t32_search",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 3.0,
13 | "theta": 1.5,
14 | "target_flops": 3.2,
15 | "resume": "model/deit_small_patch16_224-cd65a155.pth",
16 | "searching_model": "",
17 | "assigned_indicators": [],
18 | "arc_lr": 5e-4,
19 | "lr": 5e-5,
20 | "min_lr": 1e-4,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_ti_l200_t10_ft.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_tiny_patch16_224",
3 | "batch_size": 256,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_ti_l200_t10_ft",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 20.0,
13 | "theta": 1.5,
14 | "target_flops": 1.0,
15 | "resume": "model/deit_tiny_patch16_224-a1311bcf.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "epochs": 130,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_ti_l200_t10_ft_dist.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_tiny_patch16_224",
3 | "batch_size": 256,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_ti_l200_t10_ft_dist",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 20.0,
13 | "theta": 1.5,
14 | "target_flops": 1.0,
15 | "resume": "model/deit_tiny_patch16_224-a1311bcf.pth",
16 | "searching_model": "ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth",
17 | "assigned_indicators": [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "epochs": 200,
21 | "warmup_epochs": 0,
22 | "teacher_model": "regnety_160",
23 | "teacher_path": "https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth",
24 | "distillation_type": "hard"
25 | }
--------------------------------------------------------------------------------
/SPViT_DeiT/config/spvit_deit_ti_l200_t10_search.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "spvit_deit_tiny_patch16_224",
3 | "batch_size": 256,
4 | "data_path": "data/imagenet",
5 | "data_set": "IMNET",
6 | "exp_name": "spvit_deit_ti_l200_t10_search",
7 | "input_size": 224,
8 | "patch_size": 16,
9 | "num_workers": 10,
10 | "att_layer": "UnifiedAttention",
11 | "ffn_layer": "UnifiedMlp",
12 | "loss_lambda": 20.0,
13 | "theta": 1.5,
14 | "target_flops": 1.0,
15 | "resume": "model/deit_tiny_patch16_224-a1311bcf.pth",
16 | "searching_model": "",
17 | "assigned_indicators": [],
18 | "arc_lr": 1e-3,
19 | "lr": 5e-5,
20 | "min_lr": 1e-4,
21 | "warmup_epochs": 0
22 | }
--------------------------------------------------------------------------------
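The JSON configs above are consumed via `params.py`, which is not included in this listing. Below is a hypothetical sketch of the likely pattern, assuming config keys simply override argparse defaults; the keys shown are taken from the configs, but the loading logic itself is an assumption:

```python
# Hypothetical sketch of how a JSON config can override argparse defaults
# (the repo's actual logic lives in params.py, not shown in this listing).
import argparse
import json

parser = argparse.ArgumentParser('SPViT')
parser.add_argument('--config', default='config/spvit_deit_ti_l200_t10_search.json')
parser.add_argument('--eval', action='store_true')
args, _ = parser.parse_known_args()

with open(args.config) as f:
    for key, value in json.load(f).items():
        setattr(args, key, value)  # each config key becomes an args attribute

print(args.model, args.loss_lambda, args.target_flops)
```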
/SPViT_DeiT/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
4 |
5 | import os
6 | import json
7 |
8 | from torchvision import datasets, transforms
9 | from torchvision.datasets.folder import ImageFolder, default_loader
10 |
11 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12 | from timm.data import create_transform
13 |
14 |
15 | class INatDataset(ImageFolder):
16 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
17 | category='name', loader=default_loader):
18 | self.transform = transform
19 | self.loader = loader
20 | self.target_transform = target_transform
21 | self.year = year
22 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
23 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
24 | with open(path_json) as json_file:
25 | data = json.load(json_file)
26 |
27 | with open(os.path.join(root, 'categories.json')) as json_file:
28 | data_catg = json.load(json_file)
29 |
30 | path_json_for_targeter = os.path.join(root, f"train{year}.json")
31 |
32 | with open(path_json_for_targeter) as json_file:
33 | data_for_targeter = json.load(json_file)
34 |
35 | targeter = {}
36 | indexer = 0
37 | for elem in data_for_targeter['annotations']:
38 | king = []
39 | king.append(data_catg[int(elem['category_id'])][category])
40 | if king[0] not in targeter.keys():
41 | targeter[king[0]] = indexer
42 | indexer += 1
43 | self.nb_classes = len(targeter)
44 |
45 | self.samples = []
46 | for elem in data['images']:
47 | cut = elem['file_name'].split('/')
48 | target_current = int(cut[2])
49 | path_current = os.path.join(root, cut[0], cut[2], cut[3])
50 |
51 | categors = data_catg[target_current]
52 | target_current_true = targeter[categors[category]]
53 | self.samples.append((path_current, target_current_true))
54 |
55 | # __getitem__ and __len__ inherited from ImageFolder
56 |
57 |
58 | def build_dataset(is_train, args):
59 | transform = build_transform(is_train, args)
60 |
61 | if args.data_set == 'CIFAR':
62 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
63 | nb_classes = 100
64 | elif args.data_set == 'IMNET':
65 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
66 | dataset = datasets.ImageFolder(root, transform=transform)
67 | nb_classes = 1000
68 | elif args.data_set == 'INAT':
69 | dataset = INatDataset(args.data_path, train=is_train, year=2018,
70 | category=args.inat_category, transform=transform)
71 | nb_classes = dataset.nb_classes
72 | elif args.data_set == 'INAT19':
73 | dataset = INatDataset(args.data_path, train=is_train, year=2019,
74 | category=args.inat_category, transform=transform)
75 | nb_classes = dataset.nb_classes
76 | elif args.data_set == 'IMNET100':
77 | root = os.path.join(args.data_path, 'train100' if is_train else 'val100')
78 | dataset = datasets.ImageFolder(root, transform=transform)
79 | nb_classes = 100
80 |
81 | return dataset, nb_classes
82 |
83 |
84 | def build_transform(is_train, args):
85 | resize_im = args.input_size > 32
86 | if is_train:
87 | # this should always dispatch to transforms_imagenet_train
88 | transform = create_transform(
89 | input_size=args.input_size,
90 | is_training=True,
91 | color_jitter=args.color_jitter,
92 | auto_augment=args.aa,
93 | interpolation=args.train_interpolation,
94 | re_prob=args.reprob,
95 | re_mode=args.remode,
96 | re_count=args.recount,
97 | )
98 | if not resize_im:
99 | # replace RandomResizedCropAndInterpolation with
100 | # RandomCrop
101 | transform.transforms[0] = transforms.RandomCrop(
102 | args.input_size, padding=4)
103 | return transform
104 |
105 | t = []
106 | if resize_im:
107 | size = int((256 / 224) * args.input_size)
108 | t.append(
109 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
110 | )
111 | t.append(transforms.CenterCrop(args.input_size))
112 |
113 | t.append(transforms.ToTensor())
114 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
115 | return transforms.Compose(t)
116 |
--------------------------------------------------------------------------------
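A quick usage sketch for `build_dataset` above. The `args` namespace here is hypothetical and lists only the fields the function reads; the augmentation values mirror common DeiT defaults and are assumptions, since the repo's actual defaults live in `params.py`:

```python
# Usage sketch for build_dataset/build_transform (hypothetical args values).
from types import SimpleNamespace

from datasets import build_dataset

args = SimpleNamespace(
    data_set='IMNET',            # selects torchvision ImageFolder
    data_path='data/imagenet',   # matches the layout shown in the README
    input_size=224,
    # The fields below are only read when is_train=True:
    color_jitter=0.4, aa='rand-m9-mstd0.5-inc1',
    train_interpolation='bicubic', reprob=0.25, remode='pixel', recount=1,
)

dataset_val, nb_classes = build_dataset(is_train=False, args=args)
print(len(dataset_val), nb_classes)  # e.g. 50000 1000 on ImageNet-1k val
```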
/SPViT_DeiT/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
4 |
5 | """
6 | Train and eval functions used in main.py
7 | """
8 | import math
9 | import sys
10 | from typing import Iterable, Optional
11 |
12 | import torch
13 |
14 | from timm.data import Mixup
15 | from timm.utils import accuracy, ModelEma
16 |
17 | from losses import DistillationLoss
18 | import utils
19 |
20 |
21 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
22 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
23 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
24 | mixup_fn: Optional[Mixup] = None,
25 | set_training_mode=True):
26 | model.train(set_training_mode)
27 | metric_logger = utils.MetricLogger(delimiter=" ")
28 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
29 | header = 'Epoch: [{}]'.format(epoch)
30 | print_freq = 10
31 |
32 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
33 | samples = samples.to(device, non_blocking=True)
34 | targets = targets.to(device, non_blocking=True)
35 |
36 | if mixup_fn is not None:
37 | samples, targets = mixup_fn(samples, targets)
38 |
39 | with torch.cuda.amp.autocast():
40 |
41 | outputs = model(samples)
42 | loss = criterion(samples, outputs, targets)
43 |
44 | loss_value = loss.item()
45 |
46 | if not math.isfinite(loss_value):
47 | print("Loss is {}, stopping training".format(loss_value))
48 | sys.exit(1)
49 |
50 | optimizer.zero_grad()
51 |
52 | # this attribute is added by timm on one optimizer (adahessian)
53 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
54 |
55 | # loss.backward()
56 | # for name, param in model.named_parameters():
57 | # if 'thresholds' in name:
58 | # print(name, param.grad)
59 |
60 | loss_scaler(loss, optimizer, clip_grad=max_norm,
61 | parameters=model.parameters(), create_graph=is_second_order)
62 |
63 | torch.cuda.synchronize()
64 |
65 | metric_logger.update(loss=loss_value)
66 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
67 |
68 | # gather the stats from all processes
69 | metric_logger.synchronize_between_processes()
70 | print("Averaged stats:", metric_logger)
71 |
72 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
73 |
74 |
75 | def train_one_epoch_pruning(model: torch.nn.Module, criterion: DistillationLoss,
76 | data_loader: Iterable, optimizer1: torch.optim.Optimizer, optimizer2: torch.optim.Optimizer,
77 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
78 | mixup_fn: Optional[Mixup] = None,
79 | set_training_mode=True, logger=None):
80 | model.train(set_training_mode)
81 | metric_logger = utils.MetricLogger(delimiter=" ")
82 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
83 | header = 'Epoch: [{}]'.format(epoch)
84 | print_freq = 10
85 |
86 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
87 | samples = samples.to(device, non_blocking=True)
88 | targets = targets.to(device, non_blocking=True)
89 |
90 | if mixup_fn is not None:
91 | samples, targets = mixup_fn(samples, targets)
92 |
93 | with torch.cuda.amp.autocast():
94 |
95 | outputs, msa_indicators_list, msa_thresholds_list, ffn_indicators_list = model(samples)
96 | loss_cls = criterion(samples, outputs, targets)
97 |
98 | if not model.module.assigned_indicators:
99 | loss_bop = model.module.calculate_bops_loss()
100 | loss = loss_cls + loss_bop
101 | else:
102 | loss_bop = torch.zeros(1).to(loss_cls.device)
103 | loss = loss_cls
104 |
105 | loss_value = loss.item()
106 |
107 | if not math.isfinite(loss_value):
108 | print("Loss is {}, stopping training".format(loss_value))
109 | sys.exit(1)
110 |
111 | optimizer1.zero_grad()
112 | optimizer2.zero_grad()
113 |
114 | # this attribute is added by timm on one optimizer (adahessian)
115 | is_second_order1 = hasattr(optimizer1, 'is_second_order') and optimizer1.is_second_order
116 |
117 | if not model.module.assigned_indicators:
118 | loss_scaler(loss, optimizer1, optimizer2, clip_grad=max_norm, create_graph=is_second_order1, model=model)
119 | else:
120 |
121 | # Not using architecture optimizer during fine-tuning
122 | loss_scaler(loss, optimizer1, None, clip_grad=max_norm, create_graph=is_second_order1, model=model)
123 |
124 | torch.cuda.synchronize()
125 |
126 | metric_logger.update(loss=loss_value)
127 | metric_logger.update(loss_cls=loss_cls.item())
128 | metric_logger.update(loss_bop=loss_bop.item())
129 | metric_logger.update(lr=optimizer1.param_groups[0]["lr"])
130 |
131 | str_msa_thresholds = ''
132 | if not model.module.assigned_indicators and utils.get_rank() == 0:
133 | str_msa_thresholds = str(
134 | [["{:.3f}".format(i.item()) for i in blocks] for blocks in msa_thresholds_list])
135 |
136 | logger.info(str_msa_thresholds)
137 |
138 | str_ffn_indicators = str(
139 | [i.item() for i in ffn_indicators_list])
140 |
141 | logger.info(str_ffn_indicators)
142 |
143 | # break
144 |
145 | # gather the stats from all processes
146 | metric_logger.synchronize_between_processes()
147 | print("Averaged stats:", metric_logger)
148 |
149 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, str_msa_thresholds
150 |
151 |
152 | @torch.no_grad()
153 | def evaluate(data_loader, model, device):
154 | criterion = torch.nn.CrossEntropyLoss()
155 |
156 | metric_logger = utils.MetricLogger(delimiter=" ")
157 | header = 'Test:'
158 |
159 | # switch to evaluation mode
160 | model.eval()
161 |
162 | for images, target in metric_logger.log_every(data_loader, 10, header):
163 | images = images.to(device, non_blocking=True)
164 | target = target.to(device, non_blocking=True)
165 |
166 | # compute output
167 | with torch.cuda.amp.autocast():
168 | output = model(images)
169 | loss = criterion(output, target)
170 |
171 | # metric_logger.log_indicator(indicators)
172 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
173 |
174 | batch_size = images.shape[0]
175 | metric_logger.update(loss=loss.item())
176 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
177 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
178 | # gather the stats from all processes
179 | metric_logger.synchronize_between_processes()
180 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
181 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
182 |
183 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
184 |
185 |
186 | @torch.no_grad()
187 | def evaluate_pruning(data_loader, model, device):
188 | criterion = torch.nn.CrossEntropyLoss()
189 |
190 | metric_logger = utils.MetricLogger(delimiter=" ")
191 | header = 'Test:'
192 |
193 | # switch to evaluation mode
194 | model.eval()
195 |
196 | for images, target in metric_logger.log_every(data_loader, 10, header):
197 | images = images.to(device, non_blocking=True)
198 | target = target.to(device, non_blocking=True)
199 |
200 | # compute output
201 | with torch.cuda.amp.autocast():
202 | output, msa_indicators_list, msa_thresholds_list, ffn_indicators_list = model(images)
203 | loss = criterion(output, target)
204 |
205 | # metric_logger.log_indicator(indicators)
206 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
207 |
208 | batch_size = images.shape[0]
209 | metric_logger.update(loss=loss.item())
210 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
211 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
212 |
213 | # break
214 |
215 | # gather the stats from all processes
216 | metric_logger.synchronize_between_processes()
217 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
218 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
219 |
220 | str_msa_indicators = ''
221 | str_msa_thresholds = ''
222 | str_ffn_indicators = ''
223 | str_flops = ''
224 | msa_thresholds = []
225 |
226 | # If searching, print some stuff
227 | if not model.module.assigned_indicators and utils.get_rank() == 0:
228 | str_msa_indicators = str(
229 | [[i.item() for i in blocks] for blocks in msa_indicators_list])
230 |
231 | print('str_msa_indicators: ', str_msa_indicators)
232 |
233 | str_msa_thresholds = str(
234 | [["{:.3f}".format(i.item()) for i in blocks] for blocks in msa_thresholds_list])
235 |
236 | print('str_msa_thresholds: ', str_msa_thresholds)
237 |
238 | str_ffn_indicators = str(
239 | [i.item() for i in ffn_indicators_list])
240 |
241 | print('str_ffn_indicators: ', str_ffn_indicators)
242 |
243 | str_flops = str("{:.3f}".format(model.module.flops()[0].item() / 1e9))
244 | print('flops: ', str_flops)
245 |
246 | msa_thresholds = [[i.item() for i in blocks] for blocks in msa_thresholds_list]
247 |
248 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, str_msa_indicators, str_msa_thresholds,\
249 | str_ffn_indicators, str_flops, msa_thresholds
250 |
--------------------------------------------------------------------------------
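`train_one_epoch_pruning` drives two optimizers through a single `loss_scaler` call: one for the network weights and one for the architecture parameters, with the latter skipped during fine-tuning. The scaler itself is defined in `main_pruning.py`, which is not part of this listing; the following is a hypothetical stand-in with the same call signature, modeled on `Custom_scaler` from `main.py`:

```python
# Hypothetical two-optimizer AMP scaler matching the loss_scaler call in
# train_one_epoch_pruning (the real one lives in main_pruning.py).
import torch


class TwoOptimizerScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer1, optimizer2=None, clip_grad=None,
                 create_graph=False, model=None):
        # One backward pass populates gradients for both parameter groups.
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None and model is not None:
            self._scaler.unscale_(optimizer1)
            if optimizer2 is not None:
                self._scaler.unscale_(optimizer2)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        self._scaler.step(optimizer1)        # weight update
        if optimizer2 is not None:           # None during fine-tuning
            self._scaler.step(optimizer2)    # architecture-parameter update
        self._scaler.update()
```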
/SPViT_DeiT/ffn_indicators/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/.DS_Store
--------------------------------------------------------------------------------
/SPViT_DeiT/ffn_indicators/spvit_deit_bs_l006_t100_search_15epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_bs_l006_t100_search_15epoch.pth
--------------------------------------------------------------------------------
/SPViT_DeiT/ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_bs_l008_t60_search_10epoch.pth
--------------------------------------------------------------------------------
/SPViT_DeiT/ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_sm_l30_t32_search_10epoch.pth
--------------------------------------------------------------------------------
/SPViT_DeiT/ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_DeiT/ffn_indicators/spvit_deit_ti_l200_t10_search_10epoch.pth
--------------------------------------------------------------------------------
/SPViT_DeiT/hubconf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
4 |
5 | from models import *
6 |
7 | dependencies = ["torch", "torchvision", "timm"]
8 |
--------------------------------------------------------------------------------
/SPViT_DeiT/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 | import json
5 | from datetime import datetime
6 | from params import args
7 |
8 | dt = datetime.now()
9 | dt.replace(tzinfo=datetime.now().astimezone().tzinfo)
10 | _LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
11 | _DATE_FMT = '%m/%d/%Y %H:%M:%S'
12 | logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
13 | logger = logging.getLogger('__main__') # this is the global logger
14 | current_time = datetime.now().strftime('%b%d_%H-%M-%S')
15 | output_dir = os.path.join('outputs', args.exp_name)
16 | Path(output_dir).mkdir(parents=True, exist_ok=True)
17 | checkpoint_path = os.path.join(output_dir, 'last_checkpoint.pth')
18 |
19 |
20 | # Here we auto load checkpoint even if there is a resume file
21 | setattr(args, 'auto_resume', False)
22 | if os.path.exists(checkpoint_path):
23 | setattr(args, 'resume', checkpoint_path)
24 | setattr(args, 'auto_resume', True)
25 |
26 | setattr(args, 'output_dir', output_dir)
27 |
28 | if not args.eval:
29 | log_path = os.path.join(output_dir, 'all_logs.txt')
30 | with open(os.path.join(args.output_dir, 'args.json'), 'w+') as f:
31 | json.dump(vars(args), f, indent=4)
32 | else:
33 | log_path = os.path.join(output_dir, 'eval_logs.txt')
34 |
35 | fh = logging.FileHandler(log_path, 'a+')
36 | formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
37 | fh.setFormatter(formatter)
38 | logger.addHandler(fh)
--------------------------------------------------------------------------------
/SPViT_DeiT/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
4 |
5 | """
6 | Implements the knowledge distillation loss
7 | """
8 | import torch
9 | from torch.nn import functional as F
10 |
11 |
12 | class DistillationLoss(torch.nn.Module):
13 | """
14 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
15 | taking a teacher model prediction and using it as additional supervision.
16 | """
17 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
18 | distillation_type: str, alpha: float, tau: float):
19 | super().__init__()
20 | self.base_criterion = base_criterion
21 | self.teacher_model = teacher_model
22 | assert distillation_type in ['none', 'soft', 'hard']
23 | self.distillation_type = distillation_type
24 | self.alpha = alpha
25 | self.tau = tau
26 |
27 | def forward(self, inputs, outputs, labels):
28 | """
29 | Args:
30 | inputs: The original inputs that are fed to the teacher model
31 | outputs: the outputs of the model to be trained. It is expected to be
32 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
33 | in the first position and the distillation predictions as the second output
34 | labels: the labels for the base criterion
35 | """
36 |
37 | outputs_dist = outputs
38 |
39 | base_loss = self.base_criterion(outputs, labels)
40 | if self.distillation_type == 'none':
41 | return base_loss
42 |
43 | # don't backprop through the teacher
44 | with torch.no_grad():
45 | teacher_outputs = self.teacher_model(inputs)
46 |
47 | if self.distillation_type == 'soft':
48 | T = self.tau
49 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
50 | # with slight modifications
51 | distillation_loss = F.kl_div(
52 | F.log_softmax(outputs_dist / T, dim=1),
53 | F.log_softmax(teacher_outputs / T, dim=1),
54 | reduction='sum',
55 | log_target=True
56 | ) * (T * T) / outputs_dist.numel()
57 | elif self.distillation_type == 'hard':
58 | distillation_loss = F.cross_entropy(outputs_dist, teacher_outputs.argmax(dim=1))
59 |
60 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
61 | return loss
62 |
--------------------------------------------------------------------------------
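A usage sketch for `DistillationLoss`, mirroring the `regnety_160` hard-distillation settings in the `*_ft_dist.json` configs above. The `alpha` and `tau` values are assumptions (the repo's defaults come from `params.py`), and in practice `main_pruning.py` builds the teacher rather than user code:

```python
# Usage sketch for DistillationLoss (hard distillation, as in *_ft_dist.json).
# alpha/tau are assumed values; the training script wires this up for real.
import torch
from timm.models import create_model

from losses import DistillationLoss

base_criterion = torch.nn.CrossEntropyLoss()
teacher = create_model('regnety_160', pretrained=True).eval()

criterion = DistillationLoss(base_criterion, teacher,
                             distillation_type='hard', alpha=0.5, tau=1.0)

images = torch.randn(2, 3, 224, 224)        # original inputs for the teacher
student_logits = torch.randn(2, 1000, requires_grad=True)
labels = torch.tensor([3, 7])

loss = criterion(images, student_logits, labels)
loss.backward()
```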
/SPViT_DeiT/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import argparse
4 | import random
5 | import datetime
6 | import numpy as np
7 | import time
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | import json
11 | import os
12 | from pathlib import Path
13 |
14 | from timm.data import Mixup
15 | from timm.models import create_model
16 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
17 | from timm.scheduler import create_scheduler
18 | from timm.optim import create_optimizer
19 | from timm.utils import NativeScaler, get_state_dict, ModelEma
20 |
21 | from datasets import build_dataset
22 | from engine import train_one_epoch, evaluate
23 | from losses import DistillationLoss
24 | from samplers import RASampler
25 | from models import Attention, get_attention_flops
26 | import utils
27 | from params import args
28 | from logger import logger
29 |
30 | from timm.models import model_entrypoint
31 |
32 |
33 | class Custom_scaler:
34 | state_dict_key = "amp_scaler"
35 |
36 | def __init__(self):
37 | self._scaler = torch.cuda.amp.GradScaler()
38 |
39 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
40 | self._scaler.scale(loss).backward(create_graph=create_graph)
41 |
42 | if clip_grad is not None:
43 | assert parameters is not None
44 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
45 | torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
46 | self._scaler.step(optimizer)
47 | self._scaler.update()
48 |
49 | def state_dict(self):
50 | return self._scaler.state_dict()
51 |
52 | def load_state_dict(self, state_dict):
53 | self._scaler.load_state_dict(state_dict)
54 |
55 |
56 | def main():
57 | utils.init_distributed_mode(args)
58 | if utils.get_rank() != 0:
59 | logger.disabled = True
60 | print(args)
61 |
62 | if args.distillation_type != 'none' and args.finetune and not args.eval:
63 | raise NotImplementedError("Finetuning with distillation not yet supported")
64 |
65 | device = torch.device(args.device)
66 |
67 | # fix the seed for reproducibility
68 | torch.backends.cudnn.deterministic = True
69 | seed = args.seed + utils.get_rank()
70 | torch.manual_seed(seed)
71 | torch.cuda.manual_seed(seed)
72 | torch.cuda.manual_seed_all(seed)
73 | np.random.seed(seed)
74 | random.seed(seed)
75 |
76 | cudnn.benchmark = True
77 |
78 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
79 | dataset_val, _ = build_dataset(is_train=False, args=args)
80 |
81 | if True: # args.distributed:
82 | num_tasks = utils.get_world_size()
83 | global_rank = utils.get_rank()
84 | if args.repeated_aug:
85 | sampler_train = RASampler(
86 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
87 | )
88 | else:
89 | sampler_train = torch.utils.data.DistributedSampler(
90 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
91 | )
92 | if args.dist_eval:
93 | if len(dataset_val) % num_tasks != 0:
94 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
95 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
96 | 'equal num of samples per-process.')
97 | sampler_val = torch.utils.data.DistributedSampler(
98 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
99 | else:
100 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
101 | else:
102 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
103 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
104 |
105 | data_loader_train = torch.utils.data.DataLoader(
106 | dataset_train, sampler=sampler_train,
107 | batch_size=args.batch_size,
108 | num_workers=args.num_workers,
109 | pin_memory=args.pin_mem,
110 | drop_last=True,
111 | )
112 |
113 | data_loader_val = torch.utils.data.DataLoader(
114 | dataset_val, sampler=sampler_val,
115 | batch_size=int(1.5 * args.batch_size),
116 | num_workers=args.num_workers,
117 | pin_memory=args.pin_mem,
118 | drop_last=False
119 | )
120 |
121 | mixup_fn = None
122 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
123 | if mixup_active:
124 | mixup_fn = Mixup(
125 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
126 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
127 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
128 |
129 | logger.info(f"Creating model: {args.model}")
130 | model = create_model(
131 | args.model,
132 | pretrained=False,
133 | num_classes=args.nb_classes,
134 | drop_rate=args.drop,
135 | drop_path_rate=args.drop_path,
136 | drop_block_rate=None,
137 | # att_mode=args.att_mode
138 | )
139 |
140 | # if utils.get_rank() == 0:
141 | # # print_size_of_model(model)
142 | # try:
143 | # from ptflops import get_model_complexity_info
144 | # macs, params = get_model_complexity_info(model, (3, args.input_size, args.input_size), as_strings=True,
145 | # print_per_layer_stat=False, verbose=False, custom_modules_hooks={Attention:get_attention_flops})
146 | # # flops = macs
147 | # logger.info('{:<30} {:<8}'.format('MACs: ', macs))
148 | # logger.info('{:<30} {:<8}'.format('Number of parameters: ', params))
149 | # except:
150 | # pass
151 |
152 | if args.finetune:
153 | if args.finetune.startswith('https'):
154 | checkpoint = torch.hub.load_state_dict_from_url(
155 | args.finetune, map_location='cpu', check_hash=True)
156 | else:
157 | checkpoint = torch.load(args.finetune, map_location='cpu')
158 |
159 | checkpoint_model = checkpoint['model']
160 | state_dict = model.state_dict()
161 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
162 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
163 | print(f"Removing key {k} from pretrained checkpoint")
164 | del checkpoint_model[k]
165 |
166 | # interpolate position embedding
167 | pos_embed_checkpoint = checkpoint_model['pos_embed']
168 | embedding_size = pos_embed_checkpoint.shape[-1]
169 | num_patches = model.patch_embed.num_patches
170 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
171 | # height (== width) for the checkpoint position embedding
172 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
173 | # height (== width) for the new position embedding
174 | new_size = int(num_patches ** 0.5)
175 | # class_token and dist_token are kept unchanged
176 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
177 | # only the position tokens are interpolated
178 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
179 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
180 | pos_tokens = torch.nn.functional.interpolate(
181 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
182 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
183 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
184 | checkpoint_model['pos_embed'] = new_pos_embed
185 |
186 | model.load_state_dict(checkpoint_model, strict=False)
187 |
188 | model.to(device)
189 |
190 | model_without_ddp = model
191 | if args.distributed:
192 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
193 | model_without_ddp = model.module
194 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
195 | logger.info('number of params: ' + str(n_parameters))
196 |
197 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
198 | args.lr = linear_scaled_lr
199 | optimizer = create_optimizer(args, model_without_ddp)
200 | loss_scaler = NativeScaler()
201 |
202 | lr_scheduler, _ = create_scheduler(args, optimizer)
203 |
204 | criterion = LabelSmoothingCrossEntropy()
205 |
206 | if args.mixup > 0.:
207 | # smoothing is handled with mixup label transform
208 | criterion = SoftTargetCrossEntropy()
209 | elif args.smoothing:
210 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
211 | else:
212 | criterion = torch.nn.CrossEntropyLoss()
213 |
214 | teacher_model = None
215 | if args.distillation_type != 'none':
216 | assert args.teacher_path, 'need to specify teacher-path when using distillation'
217 | print(f"Creating teacher model: {args.teacher_model}")
218 | teacher_model = create_model(
219 | args.teacher_model,
220 | pretrained=False,
221 | num_classes=args.nb_classes,
222 | global_pool='avg',
223 | )
224 | if args.teacher_path.startswith('https'):
225 | checkpoint = torch.hub.load_state_dict_from_url(
226 | args.teacher_path, map_location='cpu', check_hash=True)
227 | else:
228 | checkpoint = torch.load(args.teacher_path, map_location='cpu')
229 | teacher_model.load_state_dict(checkpoint['model'])
230 | teacher_model.to(device)
231 | teacher_model.eval()
232 |
233 | # wrap the criterion in our custom DistillationLoss, which
234 | # just dispatches to the original criterion if args.distillation_type is 'none'
235 | criterion = DistillationLoss(
236 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
237 | )
238 |
239 | output_dir = Path(args.output_dir)
240 | if args.resume:
241 | if args.resume.startswith('https'):
242 | checkpoint = torch.hub.load_state_dict_from_url(
243 | args.resume, map_location='cpu', check_hash=True)
244 | else:
245 | checkpoint = torch.load(args.resume, map_location='cpu')
246 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
247 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
248 | optimizer.load_state_dict(checkpoint['optimizer'])
249 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
250 | args.start_epoch = checkpoint['epoch'] + 1
251 | # if args.model_ema:
252 | # utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
253 | if 'scaler' in checkpoint:
254 | loss_scaler.load_state_dict(checkpoint['scaler'])
255 | if args.eval:
256 | test_stats = evaluate(data_loader_val, model, device)
257 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
258 | return
259 | if args.throughput:
260 | throughput(data_loader_val, model, logger)
261 | return
262 |
263 | logger.info(f"Start training for {args.epochs} epochs")
264 | start_time = time.time()
265 | max_accuracy = 0.0
266 | for epoch in range(args.start_epoch, args.epochs):
267 | if args.distributed:
268 | data_loader_train.sampler.set_epoch(epoch)
269 |
270 | train_stats = train_one_epoch(
271 | model, criterion, data_loader_train,
272 | optimizer, device, epoch, loss_scaler,
273 | args.clip_grad, mixup_fn,
274 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning
275 | )
276 |
277 | lr_scheduler.step(epoch)
278 | if args.output_dir:
279 | checkpoint_paths = [output_dir / 'last_checkpoint.pth']
280 | for checkpoint_path in checkpoint_paths:
281 | utils.save_on_master({
282 | 'model': model_without_ddp.state_dict(),
283 | 'optimizer': optimizer.state_dict(),
284 | 'lr_scheduler': lr_scheduler.state_dict(),
285 | 'epoch': epoch,
286 | 'scaler': loss_scaler.state_dict(),
287 | 'args': args,
288 | }, checkpoint_path)
289 |
290 | test_stats = evaluate(data_loader_val, model, device)
291 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
292 | if max_accuracy < test_stats["acc1"]:
293 | utils.save_on_master({
294 | 'model': model_without_ddp.state_dict(),
295 | 'optimizer': optimizer.state_dict(),
296 | 'lr_scheduler': lr_scheduler.state_dict(),
297 | 'epoch': epoch,
298 | 'scaler': loss_scaler.state_dict(),
299 | 'args': args,
300 | }, os.path.join(args.output_dir, 'best_checkpoint.pth'))
301 |
302 | max_accuracy = max(max_accuracy, test_stats["acc1"])
303 | logger.info(f'Max accuracy: {max_accuracy:.2f}%')
304 |
305 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
306 | **{f'test_{k}': v for k, v in test_stats.items()},
307 | 'epoch': epoch,
308 | 'n_parameters': n_parameters}
309 |
310 | if args.output_dir and utils.is_main_process():
311 | with (output_dir / "log.txt").open("a") as f:
312 | f.write(json.dumps(log_stats) + "\n")
313 |
314 | total_time = time.time() - start_time
315 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
316 | logger.info('Training time {}'.format(total_time_str))
317 |
318 |
319 | @torch.no_grad()
320 | def throughput(data_loader, model, logger):
321 | model.eval()
322 |
323 | for idx, (images, _) in enumerate(data_loader):
324 | images = images.cuda(non_blocking=True)
325 | batch_size = images.shape[0]
326 | for i in range(50):
327 | model(images)
328 | torch.cuda.synchronize()
329 |         logger.info("throughput averaged over 100 iterations")
330 | tic1 = time.time()
331 | for i in range(100):
332 | model(images)
333 | torch.cuda.synchronize()
334 | tic2 = time.time()
335 | logger.info(f"batch_size {batch_size} throughput {100 * batch_size / (tic2 - tic1)}")
336 | return
337 |
338 |
339 | if __name__ == '__main__':
340 | main()
341 |
--------------------------------------------------------------------------------
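The position-embedding interpolation in the `--finetune` branch of `main.py` above is easiest to follow with concrete shapes. A standalone sketch; the 14x14 -> 24x24 grid sizes are illustrative (224px and 384px inputs with 16px patches), not values taken from the configs:

```python
import torch

embedding_size, num_extra_tokens = 768, 1          # e.g. a class token only
orig_size, new_size = 14, 24                       # 224/16 -> 384/16 patch grids

pos_embed = torch.randn(1, num_extra_tokens + orig_size * orig_size, embedding_size)

extra_tokens = pos_embed[:, :num_extra_tokens]     # class/dist tokens pass through unchanged
pos_tokens = pos_embed[:, num_extra_tokens:]

# token sequence -> NCHW grid, bicubic resize, back to a token sequence
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
    pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)

new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
assert new_pos_embed.shape == (1, num_extra_tokens + new_size * new_size, embedding_size)
```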
/SPViT_DeiT/params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import json
4 |
5 |
6 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
7 | parser.add_argument('--batch-size', default=64, type=int)
8 | parser.add_argument('--epochs', default=300, type=int)
9 |
10 | # Model parameters
11 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
12 | help='Name of model to train')
13 | parser.add_argument('--input-size', default=224, type=int, help='images input size')
14 |
15 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
16 | help='Dropout rate (default: 0.)')
17 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
18 | help='Drop path rate (default: 0.1)')
19 |
20 | parser.add_argument('--model-ema', action='store_true')
21 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
22 | parser.set_defaults(model_ema=True)
23 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
24 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
25 |
26 | # Optimizer parameters
27 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
28 |                     help='Optimizer (default: "adamw")')
29 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
30 | help='Optimizer Epsilon (default: 1e-8)')
31 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
32 | help='Optimizer Betas (default: None, use opt default)')
33 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
34 | help='Clip gradient norm (default: None, no clipping)')
35 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
36 | help='SGD momentum (default: 0.9)')
37 | parser.add_argument('--weight-decay', type=float, default=0.05,
38 | help='weight decay (default: 0.05)')
39 | # Learning rate schedule parameters
40 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
41 |                     help='LR scheduler (default: "cosine")')
42 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
43 | help='learning rate (default: 5e-4)')
44 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
45 | help='learning rate noise on/off epoch percentages')
46 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
47 | help='learning rate noise limit percent (default: 0.67)')
48 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
49 | help='learning rate noise std-dev (default: 1.0)')
50 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
51 | help='warmup learning rate (default: 1e-6)')
52 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
53 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
54 |
55 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
56 | help='epoch interval to decay LR')
57 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
58 | help='epochs to warmup LR, if scheduler supports')
59 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
60 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
61 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
62 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
63 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
64 | help='LR decay rate (default: 0.1)')
65 |
66 | # Augmentation parameters
67 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
68 | help='Color jitter factor (default: 0.4)')
69 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
70 |                     help='Use AutoAugment policy. "v0" or "original" '
71 |                          '(default: rand-m9-mstd0.5-inc1)')
72 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
73 | parser.add_argument('--train-interpolation', type=str, default='bicubic',
74 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
75 |
76 | parser.add_argument('--repeated-aug', action='store_true')
77 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
78 | parser.set_defaults(repeated_aug=True)
79 |
80 | # * Random Erase params
81 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
82 | help='Random erase prob (default: 0.25)')
83 | parser.add_argument('--remode', type=str, default='pixel',
84 | help='Random erase mode (default: "pixel")')
85 | parser.add_argument('--recount', type=int, default=1,
86 | help='Random erase count (default: 1)')
87 | parser.add_argument('--resplit', action='store_true', default=False,
88 | help='Do not random erase first (clean) augmentation split')
89 |
90 | # * Mixup params
91 | parser.add_argument('--mixup', type=float, default=0.8,
92 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
93 | parser.add_argument('--cutmix', type=float, default=1.0,
94 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
95 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
96 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
97 | parser.add_argument('--mixup-prob', type=float, default=1.0,
98 | help='Probability of performing mixup or cutmix when either/both is enabled')
99 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
100 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
101 | parser.add_argument('--mixup-mode', type=str, default='batch',
102 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
103 |
104 | # Distillation parameters
105 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
106 |                     help='Name of teacher model to train (default: "regnety_160")')
107 | parser.add_argument('--teacher-path', type=str, default='')
108 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
109 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
110 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
111 |
112 | # * Finetuning params
113 | parser.add_argument('--finetune', default='', help='finetune from checkpoint')
114 |
115 | # Dataset parameters
116 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
117 | help='dataset path')
118 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
119 |                     type=str, help='dataset name')
120 | parser.add_argument('--inat-category', default='name',
121 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
122 | type=str, help='semantic granularity')
123 |
124 | parser.add_argument('--output_dir', default='',
125 | help='path where to save, empty for no saving')
126 | parser.add_argument('--device', default='cuda',
127 | help='device to use for training / testing')
128 | parser.add_argument('--seed', default=0, type=int)
129 | parser.add_argument('--resume', default='', help='resume from checkpoint')
130 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
131 | help='start epoch')
132 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
133 | parser.add_argument('--throughput', action='store_true', help='Perform throughput only')
134 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
135 | parser.add_argument('--num_workers', default=10, type=int)
136 | parser.add_argument('--pin-mem', action='store_true',
137 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
138 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
139 | help='')
140 | parser.set_defaults(pin_mem=True)
141 |
142 | # distributed training parameters
143 | parser.add_argument('--world_size', default=1, type=int,
144 | help='number of distributed processes')
145 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
146 | parser.add_argument('--exp_name', default='deit',
147 | type=str, help='model configuration')
148 | parser.add_argument('--config', default=None,
149 | type=str, help='model configuration')
150 | parser.add_argument('--patch_size', default=16, type=int)
151 | parser.add_argument('--num_heads', default=3, type=int)
152 | parser.add_argument('--head_dim', default=64, type=int)
153 | parser.add_argument('--num_blocks', default=12, type=int)
154 | parser.add_argument('--input_size', default=224, type=int, help='images input size')
155 | parser.add_argument('--sparse_block_mode', default=0, type=int, help='sparse policy')
156 | parser.add_argument('--custom_blocks', default='', type=str, help='custom sparse blocks')
157 | parser.add_argument('--transformer_type', default='normal', type=str, help='')
158 | parser.add_argument('--local_rank', default=0, type=int, help='')
159 |
161 |
162 | args = parser.parse_args()
163 | if args.config is not None:
164 | config_args = json.load(open(args.config))
165 | override_keys = {arg[2:].split('=')[0] for arg in sys.argv[1:]
166 | if arg.startswith('--')}
167 | for k, v in config_args.items():
168 | if k not in override_keys:
169 | setattr(args, k, v)
170 | del args.config
--------------------------------------------------------------------------------
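The `--config` handling at the end of `params.py` above gives JSON values lower priority than flags passed explicitly on the command line. A self-contained sketch of the same precedence rule; the argument names and the config contents are made up for illustration:

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--lr', default=5e-4, type=float)

# Pretend the user ran: python main.py --epochs 100
argv = ['--epochs', '100']
args = parser.parse_args(argv)

config_args = json.loads('{"epochs": 30, "lr": 1e-3}')  # stand-in for the JSON config file

# A key is only taken from the config if it was not given on the command line.
override_keys = {arg[2:].split('=')[0] for arg in argv if arg.startswith('--')}
for k, v in config_args.items():
    if k not in override_keys:
        setattr(args, k, v)

print(args.epochs, args.lr)  # -> 100 0.001
```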
/SPViT_DeiT/post_training_optimize_checkpoint.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from collections import OrderedDict
4 | import ast
5 |
6 |
7 | def main():
8 |
9 | # Optimize the number of parameters of a checkpoint
10 |
11 | if len(sys.argv) != 3:
12 |         print('Error: expected two arguments: checkpoint_path and the searched MSA architecture as a list.')
13 | return
14 |
15 | checkpoint_path = sys.argv[1]
16 | state_dict = torch.load(checkpoint_path, map_location='cpu')['model']
17 |
18 | try:
19 | # Use ast.literal_eval to safely parse the string as a list
20 | MSA_indicators = ast.literal_eval(sys.argv[2])
21 |
22 | if not isinstance(MSA_indicators, list):
23 | raise ValueError("The provided parameter is not a valid list.")
24 |
25 | # Now you have the list parameter
26 | print(f"List Parameter: {MSA_indicators}")
27 |
28 | except (ValueError, SyntaxError) as e:
29 |         print(f"Invalid MSA indicators: {e}")
30 |         return
31 | new_dict = OrderedDict()
32 |
33 | if any('bconv' in key for key in list(state_dict.keys())):
34 | print('Error: The checkpoint is already optimized!')
35 | return
36 |
37 | for k, v in state_dict.items():
38 | if 'head_probs' in k:
39 | block_name = k.replace('head_probs', '')
40 | block_num = int(k.split('.')[1])
41 | head_probs = (state_dict[k] / 1e-2).softmax(0)
42 | num_heads = head_probs.shape[0]
43 | feature_dim = state_dict[block_name + 'v.weight'].shape[0]
44 | head_dim = feature_dim // num_heads
45 |
46 | if MSA_indicators[block_num][-1] == 1:
47 | print('Error: checkpoint and MSA indicators do not match!')
48 | return
49 |
50 | new_v_weight = state_dict[block_name + 'v.weight'].view(num_heads, head_dim, feature_dim).permute(1, 2, 0) @ head_probs
51 | new_v_bias = state_dict[block_name + 'v.bias'].view(num_heads, head_dim).permute(1, 0) @ head_probs
52 | new_proj_weight = state_dict[block_name + 'proj.weight'].view(feature_dim, num_heads, head_dim).permute(0, 2, 1) @ head_probs
53 |
54 | if MSA_indicators[block_num][1] == 1:
55 | bn_name = 'bn_3x3.'
56 | new_dict[block_name + 'bconv.0.weight'] = new_v_weight.permute(2, 0, 1).view(3, 3, head_dim, -1).permute(2, 3, 1, 0)
57 | new_dict[block_name + 'bconv.0.bias'] = new_v_bias.sum(-1)
58 | new_dict[block_name + 'bconv.3.weight'] = new_proj_weight.sum(-1)[..., None, None]
59 | else:
60 | bn_name = 'bn_1x1.'
61 | new_dict[block_name + 'bconv.0.weight'] = new_v_weight[..., 4][..., None, None]
62 | new_dict[block_name + 'bconv.0.bias'] = new_v_bias[..., 4]
63 | new_dict[block_name + 'bconv.3.weight'] = new_proj_weight[..., 4][..., None, None]
64 |
65 | new_dict[block_name + 'bconv.3.bias'] = state_dict[block_name + 'proj.bias']
66 |
67 | new_dict[block_name + 'bconv.1.weight'] = state_dict[block_name + bn_name + 'weight']
68 | new_dict[block_name + 'bconv.1.bias'] = state_dict[block_name + bn_name + 'bias']
69 | new_dict[block_name + 'bconv.1.running_mean'] = state_dict[block_name + bn_name + 'running_mean']
70 | new_dict[block_name + 'bconv.1.running_var'] = state_dict[block_name + bn_name + 'running_var']
71 | new_dict[block_name + 'bconv.1.num_batches_tracked'] = state_dict[block_name + bn_name + 'num_batches_tracked']
72 |
73 | else:
74 | if len(k.split('.')) <= 4 or '.'.join(k.split('.')[:-2]) + '.head_probs' not in state_dict.keys():
75 | new_dict[k] = state_dict[k]
76 |
77 | torch.save({'model': new_dict}, '.'.join(checkpoint_path.split('.')[:-1]) + '_optimized.pth')
78 |
79 |
80 | if __name__ == '__main__':
81 | main()
82 |
--------------------------------------------------------------------------------
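The key step in the script above folds the soft head-selection probabilities into the value and projection weights. A toy sketch of the same contraction, checking that the matrix product is just a probability-weighted average of the per-head weight slices. All sizes are invented, and `head_probs` is taken here as a single probability vector to keep the check small (in the repository it may carry an extra kernel-position axis):

```python
import torch

num_heads, head_dim, feature_dim = 3, 2, 6
v_weight = torch.randn(feature_dim, feature_dim)         # (num_heads * head_dim, feature_dim)
head_probs = (torch.randn(num_heads) / 1e-2).softmax(0)  # temperature 1e-2, as in the script

# Contraction used by the script: (head_dim, feature_dim, H) @ (H,) -> (head_dim, feature_dim)
merged = v_weight.view(num_heads, head_dim, feature_dim).permute(1, 2, 0) @ head_probs

# Equivalent explicit probability-weighted average over the per-head weight slices
expected = sum(head_probs[h] * v_weight.view(num_heads, head_dim, feature_dim)[h]
               for h in range(num_heads))
assert torch.allclose(merged, expected, atol=1e-5)
```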
/SPViT_DeiT/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.0
2 | torchvision==0.8.1
3 | timm==0.3.2
4 |
--------------------------------------------------------------------------------
/SPViT_DeiT/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import math
8 |
9 |
10 | class RASampler(torch.utils.data.Sampler):
11 |     """Sampler that restricts data loading to a subset of the dataset for distributed training,
12 |     with repeated augmentation.
13 |     It ensures that each augmented version of a sample is visible to a
14 |     different process (GPU).
15 | Heavily based on torch.utils.data.DistributedSampler
16 | """
17 |
18 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
19 | if num_replicas is None:
20 | if not dist.is_available():
21 | raise RuntimeError("Requires distributed package to be available")
22 | num_replicas = dist.get_world_size()
23 | if rank is None:
24 | if not dist.is_available():
25 | raise RuntimeError("Requires distributed package to be available")
26 | rank = dist.get_rank()
27 | self.dataset = dataset
28 | self.num_replicas = num_replicas
29 | self.rank = rank
30 | self.epoch = 0
31 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
32 | self.total_size = self.num_samples * self.num_replicas
33 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
34 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
35 | self.shuffle = shuffle
36 |
37 | def __iter__(self):
38 | # deterministically shuffle based on epoch
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch)
41 | if self.shuffle:
42 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
43 | else:
44 | indices = list(range(len(self.dataset)))
45 |
46 |         # repeat each sample 3 times, then add extra samples to make it evenly divisible
47 | indices = [ele for ele in indices for i in range(3)]
48 | indices += indices[:(self.total_size - len(indices))]
49 | assert len(indices) == self.total_size
50 |
51 | # subsample
52 | indices = indices[self.rank:self.total_size:self.num_replicas]
53 | assert len(indices) == self.num_samples
54 |
55 | return iter(indices[:self.num_selected_samples])
56 |
57 | def __len__(self):
58 | return self.num_selected_samples
59 |
60 | def set_epoch(self, epoch):
61 | self.epoch = epoch
62 |
--------------------------------------------------------------------------------
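The index arithmetic in `RASampler.__iter__` above is compact; a worked example with a toy dataset of 8 samples and 2 replicas makes it concrete (shuffling disabled for readability):

```python
import math

num_replicas, rank, dataset_len = 2, 0, 8
num_samples = math.ceil(dataset_len * 3 / num_replicas)   # 12
total_size = num_samples * num_replicas                    # 24

indices = list(range(dataset_len))
indices = [ele for ele in indices for _ in range(3)]       # each sample repeated 3x
indices += indices[:total_size - len(indices)]             # pad (a no-op in this example)
assert len(indices) == total_size

# Each rank takes a strided slice, spreading the repeated copies across processes.
shard = indices[rank:total_size:num_replicas]
print(shard)  # -> [0, 0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7] for rank 0
```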
/SPViT_DeiT/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = F401,E402,F403,W503,W504
4 |
--------------------------------------------------------------------------------
/SPViT_DeiT/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
4 |
5 | """
6 | Misc functions, including distributed helpers.
7 |
8 | Mostly copy-paste from torchvision references.
9 | """
10 | import io
11 | import os
12 | import time
13 | from collections import defaultdict, deque
14 | import datetime
15 |
16 | import torch
17 | import torch.distributed as dist
18 | from logger import logger
19 | from timm.scheduler.cosine_lr import CosineLRScheduler
20 | from timm.scheduler.tanh_lr import TanhLRScheduler
21 | from timm.scheduler.step_lr import StepLRScheduler
22 | from timm.scheduler.plateau_lr import PlateauLRScheduler
23 | from torch import optim as optim
24 | from timm.optim.lookahead import Lookahead
25 | from collections import OrderedDict
26 |
27 | try:
28 | from apex.optimizers import FusedAdam
29 | has_apex = True
30 | except ImportError:
31 | has_apex = False
32 |
33 |
34 | class SmoothedValue(object):
35 | """Track a series of values and provide access to smoothed values over a
36 | window or the global series average.
37 | """
38 |
39 | def __init__(self, window_size=20, fmt=None):
40 | if fmt is None:
41 | fmt = "{median:.4f} ({global_avg:.4f})"
42 | self.deque = deque(maxlen=window_size)
43 | self.total = 0.0
44 | self.count = 0
45 | self.fmt = fmt
46 |
47 | def update(self, value, n=1):
48 | self.deque.append(value)
49 | self.count += n
50 | self.total += value * n
51 |
52 | def synchronize_between_processes(self):
53 | """
54 | Warning: does not synchronize the deque!
55 | """
56 | if not is_dist_avail_and_initialized():
57 | return
58 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
59 | dist.barrier()
60 | dist.all_reduce(t)
61 | t = t.tolist()
62 | self.count = int(t[0])
63 | self.total = t[1]
64 |
65 | @property
66 | def median(self):
67 | d = torch.tensor(list(self.deque))
68 | return d.median().item()
69 |
70 | @property
71 | def avg(self):
72 | d = torch.tensor(list(self.deque), dtype=torch.float32)
73 | return d.mean().item()
74 |
75 | @property
76 | def global_avg(self):
77 | return self.total / self.count
78 |
79 | @property
80 | def max(self):
81 | return max(self.deque)
82 |
83 | @property
84 | def value(self):
85 | return self.deque[-1]
86 |
87 | def __str__(self):
88 | return self.fmt.format(
89 | median=self.median,
90 | avg=self.avg,
91 | global_avg=self.global_avg,
92 | max=self.max,
93 | value=self.value)
94 |
95 |
96 | class MetricLogger(object):
97 | def __init__(self, delimiter="\t"):
98 | self.meters = defaultdict(SmoothedValue)
99 | self.delimiter = delimiter
100 |
101 | def update(self, **kwargs):
102 | for k, v in kwargs.items():
103 | if isinstance(v, torch.Tensor):
104 | v = v.item()
105 | assert isinstance(v, (float, int))
106 | self.meters[k].update(v)
107 |
108 | def __getattr__(self, attr):
109 | if attr in self.meters:
110 | return self.meters[attr]
111 | if attr in self.__dict__:
112 | return self.__dict__[attr]
113 | raise AttributeError("'{}' object has no attribute '{}'".format(
114 | type(self).__name__, attr))
115 |
116 | def __str__(self):
117 | loss_str = []
118 | for name, meter in self.meters.items():
119 | loss_str.append(
120 | "{}: {}".format(name, str(meter))
121 | )
122 | return self.delimiter.join(loss_str)
123 |
124 | def synchronize_between_processes(self):
125 | for meter in self.meters.values():
126 | meter.synchronize_between_processes()
127 |
128 | def add_meter(self, name, meter):
129 | self.meters[name] = meter
130 |
131 | def log_every(self, iterable, print_freq, header=None):
132 | i = 0
133 | if not header:
134 | header = ''
135 | start_time = time.time()
136 | end = time.time()
137 | iter_time = SmoothedValue(fmt='{avg:.4f}')
138 | data_time = SmoothedValue(fmt='{avg:.4f}')
139 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
140 | log_msg = [
141 | header,
142 | '[{0' + space_fmt + '}/{1}]',
143 | 'eta: {eta}',
144 | '{meters}',
145 | 'time: {time}',
146 | 'data: {data}'
147 | ]
148 | if torch.cuda.is_available():
149 | log_msg.append('max mem: {memory:.0f}')
150 | log_msg = self.delimiter.join(log_msg)
151 | MB = 1024.0 * 1024.0
152 | for obj in iterable:
153 | data_time.update(time.time() - end)
154 | yield obj
155 | iter_time.update(time.time() - end)
156 | if i % print_freq == 0 or i == len(iterable) - 1:
157 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
158 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
159 | if torch.cuda.is_available():
160 | logger.info(log_msg.format(
161 | i, len(iterable), eta=eta_string,
162 | meters=str(self),
163 | time=str(iter_time), data=str(data_time),
164 | memory=torch.cuda.max_memory_allocated() / MB))
165 | else:
166 | logger.info(log_msg.format(
167 | i, len(iterable), eta=eta_string,
168 | meters=str(self),
169 | time=str(iter_time), data=str(data_time)))
170 | i += 1
171 | end = time.time()
172 | total_time = time.time() - start_time
173 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
174 | logger.info('{} Total time: {} ({:.4f} s / it)'.format(
175 | header, total_time_str, total_time / len(iterable)))
176 |
177 | def log_indicator(self, info):
178 | logger.info('indicators: ' + str(info))
179 |
180 |
181 | def _load_checkpoint_for_ema(model_ema, checkpoint):
182 | """
183 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
184 | """
185 | mem_file = io.BytesIO()
186 | torch.save(checkpoint, mem_file)
187 | mem_file.seek(0)
188 | model_ema._load_checkpoint(mem_file)
189 |
190 |
191 | def setup_for_distributed(is_master):
192 | """
193 | This function disables printing when not in master process
194 | """
195 | import builtins as __builtin__
196 | builtin_print = __builtin__.print
197 |
198 | def print(*args, **kwargs):
199 | force = kwargs.pop('force', False)
200 | if is_master or force:
201 | builtin_print(*args, **kwargs)
202 |
203 | __builtin__.print = print
204 |
205 |
206 | def is_dist_avail_and_initialized():
207 | if not dist.is_available():
208 | return False
209 | if not dist.is_initialized():
210 | return False
211 | return True
212 |
213 |
214 | def get_world_size():
215 | if not is_dist_avail_and_initialized():
216 | return 1
217 | return dist.get_world_size()
218 |
219 |
220 | def get_rank():
221 | if not is_dist_avail_and_initialized():
222 | return 0
223 | return dist.get_rank()
224 |
225 |
226 | def is_main_process():
227 | return get_rank() == 0
228 |
229 |
230 | def save_on_master(*args, **kwargs):
231 | if is_main_process():
232 | torch.save(*args, **kwargs)
233 |
234 |
235 | def init_distributed_mode(args):
236 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
237 | args.rank = int(os.environ["RANK"])
238 | args.world_size = int(os.environ['WORLD_SIZE'])
239 | args.gpu = int(os.environ['LOCAL_RANK'])
240 | elif 'SLURM_PROCID' in os.environ:
241 | args.rank = int(os.environ['SLURM_PROCID'])
242 | args.gpu = args.rank % torch.cuda.device_count()
243 | else:
244 | logger.info('Not using distributed mode')
245 | args.distributed = False
246 | return
247 |
248 | args.distributed = True
249 |
250 | torch.cuda.set_device(args.gpu)
251 | args.dist_backend = 'nccl'
252 | logger.info('| distributed init (rank {}): {}'.format(
253 | args.rank, args.dist_url))
254 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
255 | world_size=args.world_size, rank=args.rank)
256 | torch.distributed.barrier()
257 | setup_for_distributed(args.rank == 0)
258 |
259 |
260 | def create_scheduler(args, optimizer, epochs, warmup_epochs, min_lr):
261 | num_epochs = epochs
262 |
263 | if getattr(args, 'lr_noise', None) is not None:
264 | lr_noise = getattr(args, 'lr_noise')
265 | if isinstance(lr_noise, (list, tuple)):
266 | noise_range = [n * num_epochs for n in lr_noise]
267 | if len(noise_range) == 1:
268 | noise_range = noise_range[0]
269 | else:
270 | noise_range = lr_noise * num_epochs
271 | else:
272 | noise_range = None
273 |
274 | lr_scheduler = None
275 | if args.sched == 'cosine':
276 | lr_scheduler = CosineLRScheduler(
277 | optimizer,
278 | t_initial=num_epochs,
279 | t_mul=getattr(args, 'lr_cycle_mul', 1.),
280 | lr_min=args.min_lr,
281 | decay_rate=args.decay_rate,
282 | warmup_lr_init=args.warmup_lr,
283 | warmup_t=warmup_epochs,
284 | cycle_limit=getattr(args, 'lr_cycle_limit', 1),
285 | t_in_epochs=True,
286 | noise_range_t=noise_range,
287 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
288 | noise_std=getattr(args, 'lr_noise_std', 1.),
289 | noise_seed=getattr(args, 'seed', 42),
290 | )
291 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
292 | elif args.sched == 'tanh':
293 | lr_scheduler = TanhLRScheduler(
294 | optimizer,
295 | t_initial=num_epochs,
296 | t_mul=getattr(args, 'lr_cycle_mul', 1.),
297 | lr_min=args.min_lr,
298 | warmup_lr_init=args.warmup_lr,
299 | warmup_t=args.warmup_epochs,
300 | cycle_limit=getattr(args, 'lr_cycle_limit', 1),
301 | t_in_epochs=True,
302 | noise_range_t=noise_range,
303 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
304 | noise_std=getattr(args, 'lr_noise_std', 1.),
305 | noise_seed=getattr(args, 'seed', 42),
306 | )
307 | num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
308 | elif args.sched == 'step':
309 | lr_scheduler = StepLRScheduler(
310 | optimizer,
311 | decay_t=args.decay_epochs,
312 | decay_rate=args.decay_rate,
313 | warmup_lr_init=args.warmup_lr,
314 | warmup_t=args.warmup_epochs,
315 | noise_range_t=noise_range,
316 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
317 | noise_std=getattr(args, 'lr_noise_std', 1.),
318 | noise_seed=getattr(args, 'seed', 42),
319 | )
320 | elif args.sched == 'plateau':
321 | mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
322 | lr_scheduler = PlateauLRScheduler(
323 | optimizer,
324 | decay_rate=args.decay_rate,
325 | patience_t=args.patience_epochs,
326 | lr_min=args.min_lr,
327 | mode=mode,
328 | warmup_lr_init=args.warmup_lr,
329 | warmup_t=args.warmup_epochs,
330 | cooldown_t=0,
331 | noise_range_t=noise_range,
332 | noise_pct=getattr(args, 'lr_noise_pct', 0.67),
333 | noise_std=getattr(args, 'lr_noise_std', 1.),
334 | noise_seed=getattr(args, 'seed', 42),
335 | )
336 |
337 | return lr_scheduler, num_epochs
338 |
339 |
340 | def add_weight_decay_2ops(model, weight_decay=1e-5, skip_list=()):
341 | decay = []
342 | no_decay = []
343 | diff_lr = []
344 | for name, param in model.named_parameters():
345 | if not param.requires_grad:
346 | continue # frozen weights
347 | if 'thresholds' in name:
348 | diff_lr.append(param)
349 | elif len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
350 | no_decay.append(param)
351 | else:
352 | decay.append(param)
353 |
354 | return [
355 | {'params': no_decay, 'weight_decay': 0.},
356 | {'params': decay, 'weight_decay': weight_decay}], [
357 | {'params': diff_lr, 'weight_decay': 0.}]
358 |
359 |
360 | def create_2optimizers(args, model, filter_bias_and_bn=True):
361 | opt_lower = args.opt.lower()
362 | weight_decay = args.weight_decay
363 | if weight_decay and filter_bias_and_bn:
364 | skip = {}
365 | if hasattr(model, 'no_weight_decay'):
366 | skip = model.no_weight_decay()
367 | parameters1, parameters2 = add_weight_decay_2ops(model, weight_decay, skip)
368 | weight_decay = 0.
369 | else:
370 | parameters = model.parameters()
371 |
372 | if 'fused' in opt_lower:
373 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
374 |
375 | opt_args1 = dict(lr=args.lr, weight_decay=weight_decay)
376 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
377 | opt_args1['eps'] = args.opt_eps
378 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
379 | opt_args1['betas'] = args.opt_betas
380 |
381 | opt_args2 = dict(lr=args.arc_lr, weight_decay=0.)
382 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
383 | opt_args2['eps'] = args.opt_eps
384 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
385 | opt_args2['betas'] = args.opt_betas
386 |
387 | opt_split = opt_lower.split('_')
388 | opt_lower = opt_split[-1]
389 | if opt_lower == 'adamw':
390 | optimizer1 = optim.AdamW(parameters1, **opt_args1)
391 | optimizer2 = optim.AdamW(parameters2, **opt_args2)
392 |
393 | elif opt_lower == 'fusedadamw':
394 | optimizer1 = FusedAdam(parameters1, adam_w_mode=True, **opt_args1)
395 | optimizer2 = FusedAdam(parameters2, adam_w_mode=True, **opt_args2)
396 | else:
397 |         # only adamw / fusedadamw are handled here
398 |         raise ValueError('Invalid optimizer: ' + opt_lower)
399 |
400 | return optimizer1, optimizer2
401 |
402 |
403 | class NativeScaler:
404 | state_dict_key = "amp_scaler"
405 |
406 | def __init__(self):
407 | self._scaler = torch.cuda.amp.GradScaler()
408 |
409 | def __call__(self, loss, optimizer1, optimizer2, clip_grad=None, parameters=None, create_graph=False, model=None):
410 | self._scaler.scale(loss).backward(create_graph=create_graph)
411 |
412 | self._scaler.step(optimizer1)
413 |
414 | if optimizer2:
415 | self._scaler.step(optimizer2)
416 | self._scaler.update()
417 |
418 | def state_dict(self):
419 | return self._scaler.state_dict()
420 |
421 | def load_state_dict(self, state_dict):
422 | self._scaler.load_state_dict(state_dict)
423 |
424 |
425 | def prune_ffn(checkpoint_dict, ffn_indicators):
426 |
427 | depth = 12
428 |
429 | for i in range(depth):
430 | assigned_indicator_index = ffn_indicators[i].nonzero().squeeze(-1)
431 |
432 | in_dim = checkpoint_dict[f'blocks.{i}.mlp.fc1.weight'].shape[1]
433 | checkpoint_dict[f'blocks.{i}.mlp.fc1.weight'] = torch.gather(checkpoint_dict[f'blocks.{i}.mlp.fc1.weight'], 0,
434 | assigned_indicator_index.unsqueeze(-1).expand(-1, in_dim))
435 | checkpoint_dict[f'blocks.{i}.mlp.fc1.bias'] = torch.gather(checkpoint_dict[f'blocks.{i}.mlp.fc1.bias'], 0,
436 | assigned_indicator_index)
437 | checkpoint_dict[f'blocks.{i}.mlp.fc2.weight'] = torch.gather(checkpoint_dict[f'blocks.{i}.mlp.fc2.weight'], 1,
438 | assigned_indicator_index.unsqueeze(0).expand(in_dim, -1))
439 |
440 |
441 | return checkpoint_dict
442 |
443 |
444 | def save_ffn_indicators(model, epoch, logger, output_dir):
445 | new_dict = OrderedDict()
446 |
447 | old_dict = model.state_dict()
448 | for name in old_dict.keys():
449 | if 'assigned_indicator_index' in name:
450 | new_dict[name] = old_dict[name]
451 |
452 | save_path = os.path.join(output_dir, f'search_{epoch}epoch.pth')
453 | logger.info(f"{save_path} saving......")
454 | save_on_master({
455 | 'model': model.state_dict()
456 | }, save_path)
457 | logger.info(f"{save_path} saved !!!")
458 |
--------------------------------------------------------------------------------
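`prune_ffn` in `utils.py` above keeps only the FFN hidden units whose indicator is non-zero, slicing `fc1` rows and `fc2` columns consistently. A toy check of that gather pattern; the dimensions and indicator values are invented for the example:

```python
import torch

in_dim, hidden = 4, 6
fc1_w = torch.randn(hidden, in_dim)   # (hidden, in) layout, as in nn.Linear
fc1_b = torch.randn(hidden)
fc2_w = torch.randn(in_dim, hidden)

indicator = torch.tensor([1., 0., 1., 1., 0., 1.])
keep = indicator.nonzero().squeeze(-1)              # indices of kept hidden units

fc1_w_p = torch.gather(fc1_w, 0, keep.unsqueeze(-1).expand(-1, in_dim))
fc1_b_p = torch.gather(fc1_b, 0, keep)
fc2_w_p = torch.gather(fc2_w, 1, keep.unsqueeze(0).expand(in_dim, -1))

# The pruned MLP matches the full MLP with masked hidden units zeroed out.
x = torch.randn(2, in_dim)
full = (x @ fc1_w.t() + fc1_b).relu() * indicator @ fc2_w.t()
pruned = (x @ fc1_w_p.t() + fc1_b_p).relu() @ fc2_w_p.t()
assert torch.allclose(full, pruned, atol=1e-5)
```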
/SPViT_Swin/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | /output/
131 | /detectron2/
132 | detectron2
133 | detectron2/
134 | /detectron2.OLD/
135 | /detectron2.WRONG/
136 | /detectron2.egg-info/
137 |
--------------------------------------------------------------------------------
/SPViT_Swin/README.md:
--------------------------------------------------------------------------------
1 | ### Getting started on SPViT-Swin:
2 |
3 | #### Installation and data preparation
4 |
5 | - First, you can install the required environment as illustrated in the [Swin](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md) repository or follow the instructions below:
6 |
7 | ```bash
8 | # Create virtual env
9 | conda create -n spvit-swin python=3.7 -y
10 | conda activate spvit-swin
11 |
12 | # Install PyTorch
13 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
14 | pip install timm==0.3.2
15 |
16 | # Install Apex
17 | git clone https://github.com/NVIDIA/apex
18 | cd apex
19 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
20 |
21 | # Install other requirements:
22 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
23 | ```
24 |
25 | - Next, install some other dependencies that are required by SPViT:
26 |
27 | ```bash
28 | pip install tensorboardX tensorboard
29 | ```
30 |
31 | - Please refer to the [Swin](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md) repository to prepare the standard ImageNet dataset, then link the ImageNet dataset under the `dataset` folder:
32 |
33 | ```bash
34 | $ tree dataset
35 | imagenet
36 | ├── train
37 | │ ├── class1
38 | │ │ ├── img1.jpeg
39 | │ │ ├── img2.jpeg
40 | │ │ └── ...
41 | │ ├── class2
42 | │ │ ├── img3.jpeg
43 | │ │ └── ...
44 | │ └── ...
45 | └── val
46 | ├── class1
47 | │ ├── img4.jpeg
48 | │ ├── img5.jpeg
49 | │ └── ...
50 | ├── class2
51 | │ ├── img6.jpeg
52 | │ └── ...
53 | └── ...
54 | ```
55 |
56 | #### Download pretrained models
57 |
58 | - Both searching and fine-tuning start from the pre-trained models.
59 |
60 | - Since we provide training scripts for three Swin models (Swin-T, Swin-S and Swin-B), please download the corresponding pre-trained models from the [Swin](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md) repository as well.
61 |
62 | - Next, move the downloaded pre-trained models into the following file structure:
63 |
64 | ```bash
65 | $ tree model
66 | ├── swin_base_patch4_window7_224.pth
67 | ├── swin_small_patch4_window7_224.pth
68 | ├── swin_tiny_patch4_window7_224.pth
69 | ```
70 |
71 | - Note: do not change the filenames of the pre-trained models, as these filenames are hard-coded when tailoring and loading the pre-trained models. Feel free to modify the hard-coded parts when pruning from other pre-trained models.
72 |
73 | #### Searching
74 |
75 | To search architectures with SPViT-Swin-T, run:
76 |
77 | ```bash
78 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_tn_l28_t32_search.yaml --resume model/swin_tiny_patch4_window7_224.pth
79 | ```
80 |
81 | To search architectures with SPViT-Swin-S, run:
82 |
83 | ```bash
84 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_sm_l04_t55_search.yaml --resume model/swin_small_patch4_window7_224.pth
85 | ```
86 |
87 | To search architectures with SPViT-Swin-B, run:
88 |
89 | ```bash
90 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_bs_l01_t100_search.yaml --resume model/swin_base_patch4_window7_224.pth
91 | ```
92 |
93 | #### Fine-tuning
94 |
95 | You can start fine-tuning either from your own searched architectures or from our provided ones by assigning the MSA indicators in `assigned_indicators` and the FFN indicators in `searching_model`.
96 |
97 | To fine-tune architectures searched by SPViT-Swin-T, run:
98 |
99 | ```bash
100 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_tn_l28_t32_ft.yaml --resume model/swin_tiny_patch4_window7_224.pth
101 | ```
102 |
103 | To fine-tune the architectures with SPViT-Swin-S, run:
104 |
105 | ```bash
106 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_sm_l04_t55_ft.yaml --resume model/swin_small_patch4_window7_224.pth
107 | ```
108 |
109 | To fine-tune the architectures with SPViT-Swin-B, run:
110 |
111 | ```bash
112 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_bs_l01_t100_ft.yaml --resume model/swin_base_patch4_window7_224.pth
113 | ```
114 |
115 | #### Evaluation
116 |
117 | We provide several examples for evaluating pre-trained SPViT models.
118 |
119 | To evaluate SPViT-Swin-T pre-trained models, run:
120 |
121 | ```bash
122 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_tn_l28_t32_ft.yaml --resume [PRE-TRAINED MODEL PATH] --opts EVAL_MODE True
123 | ```
124 |
125 | To evaluate SPViT-Swin-S pre-trained models, run:
126 |
127 | ```bash
128 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_sm_l04_t55_ft.yaml --resume [PRE-TRAINED MODEL PATH] --opts EVAL_MODE True
129 | ```
130 |
131 | To evaluate SPViT-Swin-B pre-trained models, run:
132 |
133 | ```bash
134 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 3132 main_pruning.py --cfg configs/spvit_swin_bs_l01_t100_ft.yaml --resume [PRE-TRAINED MODEL PATH] --opts EVAL_MODE True
135 | ```
136 |
137 | After fine-tuning, you can optimize your checkpoint to a smaller size with the following command:
138 | ```bash
139 | python post_training_optimize_checkpoint.py YOUR_CHECKPOINT_PATH
140 | ```
141 | The optimized checkpoint can be evaluated by replacing `UnifiedWindowAttention` with `UnifiedWindowAttentionParamOpt`; an example is provided below:
142 | ```bash
143 | python main_pruning.py \
144 |     --cfg configs/spvit_swin_tn_l28_t32_ft_dist.yaml \
145 |     --resume model/spvit_swin_t_l28_t32_dist_optimized.pth \
146 |     --opts EVAL_MODE True EXTRA.attention_type UnifiedWindowAttentionParamOpt \
147 |     --local_rank 0
155 | ```
157 |
158 | #### TODO:
159 |
160 | - [x] Release code.
161 | - [x] Release pre-trained models.
164 |
--------------------------------------------------------------------------------
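As a quick sanity check on the checkpoint-optimization step described in the README above, a hedged sketch for comparing parameter counts before and after. It assumes the checkpoint stores its weights under a `'model'` key, as the DeiT-side script does, and the paths are placeholders for your own files:

```python
import torch

def num_params(path):
    # Assumes the weights live under the 'model' key, as in the DeiT-side script.
    state = torch.load(path, map_location='cpu')['model']
    return sum(v.numel() for v in state.values())

# Placeholder paths: your fine-tuned checkpoint and its optimized counterpart.
print('before:', num_params('YOUR_CHECKPOINT_PATH'))
print('after: ', num_params('YOUR_CHECKPOINT_PATH_optimized.pth'))
```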
/SPViT_Swin/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
8 |
9 | import os
10 | import yaml
11 | from yacs.config import CfgNode as CN
12 |
13 | _C = CN()
14 |
15 | # Base config files
16 | _C.BASE = ['']
17 |
18 | # -----------------------------------------------------------------------------
19 | # Data settings
20 | # -----------------------------------------------------------------------------
21 | _C.DATA = CN()
22 | # Batch size for a single GPU, could be overwritten by command line argument
23 | _C.DATA.BATCH_SIZE = 128
24 | # Path to dataset, could be overwritten by command line argument
25 | _C.DATA.DATA_PATH = ''
26 | # Dataset name
27 | _C.DATA.DATASET = 'imagenet'
28 | # Input image size
29 | _C.DATA.IMG_SIZE = 224
30 | # Interpolation to resize image (random, bilinear, bicubic)
31 | _C.DATA.INTERPOLATION = 'bicubic'
32 | # Use zipped dataset instead of folder dataset
33 | # could be overwritten by command line argument
34 | _C.DATA.ZIP_MODE = False
35 | # Cache Data in Memory, could be overwritten by command line argument
36 | _C.DATA.CACHE_MODE = 'part'
37 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
38 | _C.DATA.PIN_MEMORY = True
39 | # Number of data loading threads
40 | _C.DATA.NUM_WORKERS = 8
41 |
42 | # -----------------------------------------------------------------------------
43 | # Model settings
44 | # -----------------------------------------------------------------------------
45 | _C.MODEL = CN()
46 | # Model type
47 | _C.MODEL.TYPE = 'swin'
48 | # Model name
49 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
50 | # Checkpoint to resume, could be overwritten by command line argument
51 | _C.MODEL.RESUME = ''
52 | # Number of classes, overwritten in data preparation
53 | _C.MODEL.NUM_CLASSES = 1000
54 | # Dropout rate
55 | _C.MODEL.DROP_RATE = 0.0
56 | # Drop path rate
57 | _C.MODEL.DROP_PATH_RATE = 0.1
58 | # Label Smoothing
59 | _C.MODEL.LABEL_SMOOTHING = 0.1
60 |
61 | # Swin Transformer parameters
62 | _C.MODEL.SWIN = CN()
63 | _C.MODEL.SWIN.PATCH_SIZE = 4
64 | _C.MODEL.SWIN.IN_CHANS = 3
65 | _C.MODEL.SWIN.EMBED_DIM = 96
66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
67 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
68 | _C.MODEL.SWIN.WINDOW_SIZE = 7
69 | _C.MODEL.SWIN.MLP_RATIO = 4.
70 | _C.MODEL.SWIN.QKV_BIAS = True
71 | _C.MODEL.SWIN.QK_SCALE = None
72 | _C.MODEL.SWIN.APE = False
73 | _C.MODEL.SWIN.PATCH_NORM = True
74 |
75 | # -----------------------------------------------------------------------------
76 | # Training settings
77 | # -----------------------------------------------------------------------------
78 | _C.TRAIN = CN()
79 | _C.TRAIN.START_EPOCH = 0
80 | _C.TRAIN.EPOCHS = 300
81 | _C.TRAIN.WARMUP_EPOCHS = 20
82 | _C.TRAIN.WEIGHT_DECAY = 0.05
83 | _C.TRAIN.BASE_LR = 5e-4
84 | _C.TRAIN.WARMUP_LR = 5e-7
85 | _C.TRAIN.MIN_LR = 5e-6
86 | # Clip gradient norm
87 | _C.TRAIN.CLIP_GRAD = 5.0
88 | # Auto resume from latest checkpoint
89 | _C.TRAIN.AUTO_RESUME = True
90 | # Gradient accumulation steps
91 | # could be overwritten by command line argument
92 | _C.TRAIN.ACCUMULATION_STEPS = 0
93 | # Whether to use gradient checkpointing to save memory
94 | # could be overwritten by command line argument
95 | _C.TRAIN.USE_CHECKPOINT = False
96 |
97 | # LR scheduler
98 | _C.TRAIN.LR_SCHEDULER = CN()
99 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
100 | # Epoch interval to decay LR, used in StepLRScheduler
101 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
102 | # LR decay rate, used in StepLRScheduler
103 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
104 |
105 | # Optimizer
106 | _C.TRAIN.OPTIMIZER = CN()
107 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
108 | # Optimizer Epsilon
109 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
110 | # Optimizer Betas
111 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
112 | # SGD momentum
113 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
114 |
115 | # -----------------------------------------------------------------------------
116 | # Augmentation settings
117 | # -----------------------------------------------------------------------------
118 | _C.AUG = CN()
119 | # Color jitter factor
120 | _C.AUG.COLOR_JITTER = 0.4
121 | # Use AutoAugment policy. "v0" or "original"
122 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
123 | # Random erase prob
124 | _C.AUG.REPROB = 0.25
125 | # Random erase mode
126 | _C.AUG.REMODE = 'pixel'
127 | # Random erase count
128 | _C.AUG.RECOUNT = 1
129 | # Mixup alpha, mixup enabled if > 0
130 | _C.AUG.MIXUP = 0.8
131 | # Cutmix alpha, cutmix enabled if > 0
132 | _C.AUG.CUTMIX = 1.0
133 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
134 | _C.AUG.CUTMIX_MINMAX = None
135 | # Probability of performing mixup or cutmix when either/both is enabled
136 | _C.AUG.MIXUP_PROB = 1.0
137 | # Probability of switching to cutmix when both mixup and cutmix enabled
138 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
139 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
140 | _C.AUG.MIXUP_MODE = 'batch'
141 |
142 | # -----------------------------------------------------------------------------
143 | # SPViT extra settings
144 | # -----------------------------------------------------------------------------
145 | _C.EXTRA = CN()
146 |
147 | # Fine-tuning settings
148 | _C.EXTRA.searching_model = None
149 | _C.EXTRA.assigned_indicators = None
150 |
151 | # Architecture hyper-parameters
152 | _C.EXTRA.architecture_lr = 5e-4
153 | _C.EXTRA.arc_decay = 100
154 | _C.EXTRA.arc_warmup = 20
155 | _C.EXTRA.arc_min_lr = 5e-6
156 |
157 | # Hyper-parameters
158 | _C.EXTRA.theta = 0.5 # Bernoulli gates' initial parameter
159 | _C.EXTRA.alpha = 1e2 # Softmax temperature for ensembling heads
160 | _C.EXTRA.loss_lambda = 0.14
161 | _C.EXTRA.target_flops = 3.6
162 |
163 | _C.EXTRA.teacher_model = 'regnety_160'
164 | _C.EXTRA.teacher_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'
165 | _C.EXTRA.distillation_type = 'none'
166 | _C.EXTRA.distillation_alpha = 0.5
167 | _C.EXTRA.distillation_tau = 1.0
168 | _C.EXTRA.attention_type = 'UnifiedWindowAttention'
169 |
170 | # -----------------------------------------------------------------------------
171 | # Testing settings
172 | # -----------------------------------------------------------------------------
173 | _C.TEST = CN()
174 | # Whether to use center crop when testing
175 | _C.TEST.CROP = True
176 |
177 | # -----------------------------------------------------------------------------
178 | # Misc
179 | # -----------------------------------------------------------------------------
180 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
181 | # overwritten by command line argument
182 | _C.AMP_OPT_LEVEL = ''
183 | # Path to output folder, overwritten by command line argument
184 | _C.OUTPUT = ''
185 | # Tag of experiment, overwritten by command line argument
186 | _C.TAG = 'default'
187 | # Frequency to save checkpoint
188 | _C.SAVE_FREQ = 50
190 | # Frequency to log info
190 | _C.PRINT_FREQ = 10
191 | # Fixed random seed
192 | _C.SEED = 0
193 | # Perform evaluation only, overwritten by command line argument
194 | _C.EVAL_MODE = False
195 | # Test throughput only, overwritten by command line argument
196 | _C.THROUGHPUT_MODE = False
197 | # local rank for DistributedDataParallel, given by command line argument
198 | _C.LOCAL_RANK = 0
199 |
200 |
201 | def _update_config_from_file(config, cfg_file):
202 | config.defrost()
203 | with open(cfg_file, 'r') as f:
204 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
205 |
206 | for cfg in yaml_cfg.setdefault('BASE', ['']):
207 | if cfg:
208 | _update_config_from_file(
209 | config, os.path.join(os.path.dirname(cfg_file), cfg)
210 | )
211 |
212 | print('=> merge config from {}'.format(cfg_file))
213 |
214 | config.merge_from_file(cfg_file)
215 | config.freeze()
216 |
217 |
218 | def update_config(config, args):
219 | _update_config_from_file(config, args.cfg)
220 |
221 | config.defrost()
222 | if args.opts:
223 | config.merge_from_list(args.opts)
224 |
225 | # merge from specific arguments
226 | if args.batch_size:
227 | config.DATA.BATCH_SIZE = args.batch_size
228 | if args.data_path:
229 | config.DATA.DATA_PATH = args.data_path
230 | if args.zip:
231 | config.DATA.ZIP_MODE = True
232 | if args.cache_mode:
233 | config.DATA.CACHE_MODE = args.cache_mode
234 | if args.resume:
235 | config.MODEL.RESUME = args.resume
236 | if args.accumulation_steps:
237 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
238 | if args.use_checkpoint:
239 | config.TRAIN.USE_CHECKPOINT = True
240 | if args.amp_opt_level:
241 | config.AMP_OPT_LEVEL = args.amp_opt_level
242 | if args.output:
243 | config.OUTPUT = args.output
244 | if args.tag:
245 | config.TAG = args.tag
246 | if args.eval:
247 | config.EVAL_MODE = True
248 | if args.throughput:
249 | config.THROUGHPUT_MODE = True
250 |
251 | # set local rank for distributed training
252 | config.LOCAL_RANK = args.local_rank
253 |
254 | # output folder
255 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
256 | config.freeze()
257 |
258 |
259 | def get_config(args):
260 | """Get a yacs CfgNode object with default values."""
261 | # Return a clone so that the defaults will not be altered
262 | # This is for the "local variable" use pattern
263 | config = _C.clone()
264 | update_config(config, args)
265 |
266 | return config
267 |
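For reference, the merge order implemented above is: hard-coded defaults, then any `BASE` files, then the given YAML, then `--opts` key/value pairs, and finally the explicit command-line arguments. A minimal sketch of driving `get_config` programmatically, assuming it runs from the `SPViT_Swin` directory; the `Namespace` fields mirror exactly the attributes `update_config` reads, and all values are placeholders:

```python
# Hypothetical driver for get_config; all values below are placeholders.
from argparse import Namespace
from config import get_config

args = Namespace(
    cfg='configs/spvit_swin_tn_l28_t32_ft.yaml',  # any config under ./configs
    opts=['EXTRA.target_flops', '3.2'],           # dotted-key overrides
    batch_size=64, data_path='dataset/imagenet', zip=False, cache_mode=None,
    resume='', accumulation_steps=0, use_checkpoint=False, amp_opt_level='',
    output='output', tag='demo', eval=False, throughput=False, local_rank=0,
)

config = get_config(args)
print(config.MODEL.NAME, config.EXTRA.target_flops)  # spvit_swin_tn_l28_t32_ft 3.2
```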
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_bs_l01_t100_ft.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_bs_l01_t100_ft
4 | DROP_PATH_RATE: 0.5
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 128
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 0.1
17 | arc_decay: 150
18 | target_flops: 10.0
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | assigned_indicators: [[[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
25 | searching_model: 'ffn_indicators/spvit_swin_bs_l01_t100_search_20epoch.pth'
26 | TRAIN:
27 | EPOCHS: 130
28 | WARMUP_EPOCHS: 0
29 | BASE_LR: 5e-5
30 | #EVAL_MODE: True
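The nested `assigned_indicators` above mirror the model topology: one group per stage, matching `DEPTHS`, with a three-entry FFN indicator per block. A quick shape check (not part of the repo), assuming it runs from the `SPViT_Swin` directory:

```python
# Sanity-check the assigned_indicators shape against DEPTHS.
import yaml

with open('configs/spvit_swin_bs_l01_t100_ft.yaml') as f:
    cfg = yaml.safe_load(f)

indicators = cfg['EXTRA']['assigned_indicators']
assert [len(stage) for stage in indicators] == cfg['MODEL']['SWIN']['DEPTHS']  # [2, 2, 18, 2]
assert all(len(block) == 3 for stage in indicators for block in stage)  # 3 indicators per block
```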
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_bs_l01_t100_search.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_bs_l01_t100_search
4 | DROP_PATH_RATE: 0.5
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 92
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 0.1
17 | arc_decay: 150
18 | target_flops: 10.0
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | TRAIN:
25 | EPOCHS: 300
26 | WARMUP_EPOCHS: 0
27 | BASE_LR: 5e-5
28 | MIN_LR: 5e-5
29 | #EVAL_MODE: True
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_sm_l04_t55_ft.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_sm_l04_t55_ft
4 | DROP_PATH_RATE: 0.3
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 128
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 0.4
17 | arc_decay: 150
18 | target_flops: 5.5
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
25 | searching_model: 'ffn_indicators/spvit_swin_sm_l04_t55_search_14epoch.pth'
26 | TRAIN:
27 | EPOCHS: 130
28 | WARMUP_EPOCHS: 0
29 | BASE_LR: 5e-5
30 | #EVAL_MODE: True
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_sm_l04_t55_ft_dist.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_sm_l04_t55_ft_dist
4 | DROP_PATH_RATE: 0.3
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 128
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 0.4
17 | arc_decay: 150
18 | target_flops: 5.5
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
25 | searching_model: 'ffn_indicators/spvit_swin_sm_l04_t55_search_14epoch.pth'
26 | distillation_type: 'hard'
27 | TRAIN:
28 | EPOCHS: 200
29 | WARMUP_EPOCHS: 0
30 | BASE_LR: 5e-5
31 | #EVAL_MODE: True
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_sm_l04_t55_search.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_sm_l04_t55_search
4 | DROP_PATH_RATE: 0.3
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 128
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 0.4
17 | arc_decay: 150
18 | target_flops: 5.5
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | TRAIN:
25 | EPOCHS: 300
26 | WARMUP_EPOCHS: 0
27 | BASE_LR: 5e-5
28 | MIN_LR: 5e-5
29 | #EVAL_MODE: True
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_tn_l28_t32_ft.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_tn_l28_t32_ft
4 | DROP_PATH_RATE: 0.2
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 6, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 128
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 2.8
17 | arc_decay: 150
18 | target_flops: 3.2
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
25 | searching_model: 'ffn_indicators/spvit_swin_t_l28_t32_search_12epoch.pth'
26 | TRAIN:
27 | EPOCHS: 130
28 | WARMUP_EPOCHS: 0
29 | BASE_LR: 5e-5
30 | #EVAL_MODE: True
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_tn_l28_t32_ft_dist.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_tn_l28_t32_ft_dist
4 | DROP_PATH_RATE: 0.2
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 6, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 128
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 2.8
17 | arc_decay: 150
18 | target_flops: 3.2
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | assigned_indicators: [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]
25 | searching_model: 'ffn_indicators/spvit_swin_t_l28_t32_search_12epoch.pth'
26 | distillation_type: 'hard'
27 | TRAIN:
28 | EPOCHS: 200
29 | WARMUP_EPOCHS: 0
30 | BASE_LR: 5e-5
31 | #EVAL_MODE: True
--------------------------------------------------------------------------------
/SPViT_Swin/configs/spvit_swin_tn_l28_t32_search.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: spvit_swin
3 | NAME: spvit_swin_tn_l28_t32_search
4 | DROP_PATH_RATE: 0.2
5 | SWIN:
6 | EMBED_DIM: 96
7 | DEPTHS: [ 2, 2, 6, 2 ]
8 | NUM_HEADS: [ 3, 6, 12, 24 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | NUM_WORKERS: 10
12 | BATCH_SIZE: 128
13 | DATA_PATH: dataset/imagenet
14 | DATASET: imagenet
15 | EXTRA:
16 | loss_lambda: 2.8
17 | arc_decay: 150
18 | target_flops: 3.2
19 | arc_warmup: 0
20 | arc_min_lr: 5e-4
21 | architecture_lr: 5e-4
22 | alpha: 1e2
23 | theta: 1.5
24 | TRAIN:
25 | EPOCHS: 300
26 | WARMUP_EPOCHS: 0
27 | BASE_LR: 5e-5
28 | MIN_LR: 5e-5
29 | #EVAL_MODE: True
--------------------------------------------------------------------------------
/SPViT_Swin/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_loader, build_loader_darts, build_loader_darts_v2
--------------------------------------------------------------------------------
/SPViT_Swin/data/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
8 |
9 | import os
10 | import torch
11 | import numpy as np
12 | import torch.distributed as dist
13 | from torchvision import datasets, transforms
14 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15 | from timm.data import Mixup
16 | from timm.data import create_transform
17 | from timm.data.transforms import _pil_interp
18 |
19 | from .cached_image_folder import CachedImageFolder
20 | from .samplers import SubsetRandomSampler
21 | from utils import DistributedSamplerWrapper
22 |
23 | def build_loader(config):
24 | config.defrost()
25 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
26 | config.freeze()
27 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
28 | dataset_val, _ = build_dataset(is_train=False, config=config)
29 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
30 |
31 | num_tasks = dist.get_world_size()
32 | global_rank = dist.get_rank()
33 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
34 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
35 | sampler_train = SubsetRandomSampler(indices)
36 | else:
37 | sampler_train = torch.utils.data.DistributedSampler(
38 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
39 | )
40 |
41 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
42 | sampler_val = SubsetRandomSampler(indices)
43 |
44 | data_loader_train = torch.utils.data.DataLoader(
45 | dataset_train, sampler=sampler_train,
46 | batch_size=config.DATA.BATCH_SIZE,
47 | num_workers=config.DATA.NUM_WORKERS,
48 | pin_memory=config.DATA.PIN_MEMORY,
49 | drop_last=True,
50 | )
51 |
52 | data_loader_val = torch.utils.data.DataLoader(
53 | dataset_val, sampler=sampler_val,
54 | batch_size=config.DATA.BATCH_SIZE,
55 | shuffle=False,
56 | num_workers=config.DATA.NUM_WORKERS,
57 | pin_memory=config.DATA.PIN_MEMORY,
58 | drop_last=False
59 | )
60 |
61 | # setup mixup / cutmix
62 | mixup_fn = None
63 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
64 | if mixup_active:
65 | mixup_fn = Mixup(
66 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
67 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
68 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
69 |
70 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
71 |
72 |
73 | def build_loader_darts(config):
74 | config.defrost()
75 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
76 | config.freeze()
77 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
78 | dataset_val, _ = build_dataset(is_train=False, config=config)
79 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
80 |
81 | num_tasks = dist.get_world_size()
82 | global_rank = dist.get_rank()
83 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
84 |         # zip-mode never defined the two DARTS samplers used below (NameError);
85 |         raise NotImplementedError("ZIP_MODE is not supported by build_loader_darts")
86 | else:
87 |
88 | num_train = len(dataset_train)
89 | indices = list(range(num_train))
90 | split = int(np.floor(0.5 * num_train))
91 | traintrain_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
92 | trainval_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train])
93 | dis_traintrain_sampler = DistributedSamplerWrapper(traintrain_sampler, num_replicas=num_tasks, rank=global_rank, shuffle=True)
94 | dis_trainval_sampler = DistributedSamplerWrapper(trainval_sampler, num_replicas=num_tasks, rank=global_rank, shuffle=True)
95 |
96 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
97 | sampler_val = SubsetRandomSampler(indices)
98 |
99 | train_queue = torch.utils.data.DataLoader(
100 | dataset_train, sampler=dis_traintrain_sampler,
101 | batch_size=config.DATA.BATCH_SIZE,
102 | num_workers=0,
103 | pin_memory=config.DATA.PIN_MEMORY,
104 | drop_last=True,
105 | )
106 |
107 | val_queue = torch.utils.data.DataLoader(
108 | dataset_train, sampler=dis_trainval_sampler,
109 | batch_size=config.DATA.BATCH_SIZE,
110 | num_workers=0,
111 | pin_memory=config.DATA.PIN_MEMORY,
112 | drop_last=True,
113 | )
114 |
115 | data_loader_val = torch.utils.data.DataLoader(
116 | dataset_val, sampler=sampler_val,
117 | batch_size=config.DATA.BATCH_SIZE,
118 | shuffle=False,
119 | num_workers=config.DATA.NUM_WORKERS,
120 | pin_memory=config.DATA.PIN_MEMORY,
121 | drop_last=False
122 | )
123 |
124 | # setup mixup / cutmix
125 | mixup_fn = None
126 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
127 | if mixup_active:
128 | mixup_fn = Mixup(
129 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
130 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
131 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
132 |
133 | return dataset_train, dataset_val, train_queue, val_queue, data_loader_val, mixup_fn
134 |
135 |
136 | def build_loader_darts_v2(config):
137 | config.defrost()
138 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
139 | config.freeze()
140 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
141 | dataset_val, _ = build_dataset(is_train=False, config=config)
142 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
143 |
144 | num_tasks = dist.get_world_size()
145 | global_rank = dist.get_rank()
146 |
147 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
148 |         # zip-mode never defined the split samplers used below (NameError);
149 |         raise NotImplementedError("ZIP_MODE is not supported by build_loader_darts_v2")
150 | else:
151 | num_train = len(dataset_train)
152 | split = int(np.floor(0.5 * num_train))
153 |
154 | dataset_traintrain, dataset_trainval = torch.utils.data.random_split(
155 | dataset_train,
156 | (len(dataset_train) - split, split)
157 | )
158 |
159 | sampler_traintrain = torch.utils.data.DistributedSampler(
160 | dataset_traintrain, num_replicas=num_tasks, rank=global_rank, shuffle=True
161 | )
162 |
163 | sampler_trainval = torch.utils.data.DistributedSampler(
164 | dataset_trainval, num_replicas=num_tasks, rank=global_rank, shuffle=True
165 | )
166 |
167 | # sampler_traintrain = torch.utils.data.RandomSampler(
168 | # dataset_traintrain
169 | # )
170 | #
171 | # sampler_trainval = torch.utils.data.RandomSampler(
172 | # dataset_trainval
173 | # )
174 |
175 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size())
176 | sampler_val = SubsetRandomSampler(indices)
177 |
178 | train_queue = torch.utils.data.DataLoader(
179 |         dataset_traintrain, sampler=sampler_traintrain,  # the subset the sampler indexes
180 | batch_size=config.DATA.BATCH_SIZE,
181 | num_workers=2,
182 | pin_memory=config.DATA.PIN_MEMORY,
183 | drop_last=True,
184 | persistent_workers=False
185 | )
186 |
187 | val_queue = torch.utils.data.DataLoader(
188 |         dataset_trainval, sampler=sampler_trainval,  # the subset the sampler indexes
189 | batch_size=config.DATA.BATCH_SIZE,
190 | num_workers=2,
191 | pin_memory=config.DATA.PIN_MEMORY,
192 | drop_last=True,
193 | persistent_workers=False
194 | )
195 |
196 | data_loader_val = torch.utils.data.DataLoader(
197 | dataset_val, sampler=sampler_val,
198 | batch_size=config.DATA.BATCH_SIZE,
199 | shuffle=False,
200 | num_workers=config.DATA.NUM_WORKERS,
201 | pin_memory=config.DATA.PIN_MEMORY,
202 | drop_last=False
203 | )
204 |
205 | # setup mixup / cutmix
206 | mixup_fn = None
207 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
208 | if mixup_active:
209 | mixup_fn = Mixup(
210 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
211 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
212 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
213 |
214 | return dataset_train, dataset_val, train_queue, val_queue, data_loader_val, mixup_fn
215 |
216 |
217 | def build_dataset(is_train, config):
218 | transform = build_transform(is_train, config)
219 | if config.DATA.DATASET == 'imagenet':
220 | prefix = 'train' if is_train else 'val'
221 | if config.DATA.ZIP_MODE:
222 | ann_file = prefix + "_map.txt"
223 | prefix = prefix + ".zip@/"
224 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
225 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
226 | else:
227 | root = os.path.join(config.DATA.DATA_PATH, prefix)
228 | dataset = datasets.ImageFolder(root, transform=transform)
229 | nb_classes = 1000
230 | elif config.DATA.DATASET == 'cifar':
231 | dataset = datasets.CIFAR100(config.DATA.DATA_PATH, train=is_train, transform=transform)
232 | nb_classes = 100
233 | elif config.DATA.DATASET == 'imagenet100':
234 | root = os.path.join(config.DATA.DATA_PATH, 'train100' if is_train else 'val100')
235 | dataset = datasets.ImageFolder(root, transform=transform)
236 | nb_classes = 100
237 | else:
238 | raise NotImplementedError("We only support ImageNet Now.")
239 |
240 | return dataset, nb_classes
241 |
242 |
243 | def build_transform(is_train, config):
244 | resize_im = config.DATA.IMG_SIZE > 32
245 | if is_train:
246 | # this should always dispatch to transforms_imagenet_train
247 | transform = create_transform(
248 | input_size=config.DATA.IMG_SIZE,
249 | is_training=True,
250 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
251 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
252 | re_prob=config.AUG.REPROB,
253 | re_mode=config.AUG.REMODE,
254 | re_count=config.AUG.RECOUNT,
255 | interpolation=config.DATA.INTERPOLATION,
256 | )
257 | if not resize_im:
258 | # replace RandomResizedCropAndInterpolation with
259 | # RandomCrop
260 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
261 | return transform
262 |
263 | t = []
264 | if resize_im:
265 | if config.TEST.CROP:
266 | size = int((256 / 224) * config.DATA.IMG_SIZE)
267 | t.append(
268 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
269 | # to maintain same ratio w.r.t. 224 images
270 | )
271 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
272 | else:
273 | t.append(
274 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
275 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
276 | )
277 |
278 | t.append(transforms.ToTensor())
279 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
280 | return transforms.Compose(t)
281 |
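The eval branch of `build_transform` keeps the standard train/test resize ratio; a worked check of the arithmetic, assuming the default `IMG_SIZE` of 224 with `TEST.CROP` enabled:

```python
# Worked example of the TEST.CROP branch above (IMG_SIZE = 224).
img_size = 224
resize_size = int((256 / 224) * img_size)  # 256: resize so the ratio matches 224-crop training
crop_size = img_size                       # 224: then take a center crop
assert (resize_size, crop_size) == (256, 224)
```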
--------------------------------------------------------------------------------
/SPViT_Swin/data/cached_image_folder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
8 |
9 | import io
10 | import os
11 | import time
12 | import torch.distributed as dist
13 | import torch.utils.data as data
14 | from PIL import Image
15 |
16 | from .zipreader import is_zip_path, ZipReader
17 |
18 |
19 | def has_file_allowed_extension(filename, extensions):
20 | """Checks if a file is an allowed extension.
21 | Args:
22 | filename (string): path to a file
23 | Returns:
24 | bool: True if the filename ends with a known image extension
25 | """
26 | filename_lower = filename.lower()
27 | return any(filename_lower.endswith(ext) for ext in extensions)
28 |
29 |
30 | def find_classes(dir):
31 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
32 | classes.sort()
33 | class_to_idx = {classes[i]: i for i in range(len(classes))}
34 | return classes, class_to_idx
35 |
36 |
37 | def make_dataset(dir, class_to_idx, extensions):
38 | images = []
39 | dir = os.path.expanduser(dir)
40 | for target in sorted(os.listdir(dir)):
41 | d = os.path.join(dir, target)
42 | if not os.path.isdir(d):
43 | continue
44 |
45 | for root, _, fnames in sorted(os.walk(d)):
46 | for fname in sorted(fnames):
47 | if has_file_allowed_extension(fname, extensions):
48 | path = os.path.join(root, fname)
49 | item = (path, class_to_idx[target])
50 | images.append(item)
51 |
52 | return images
53 |
54 |
55 | def make_dataset_with_ann(ann_file, img_prefix, extensions):
56 | images = []
57 | with open(ann_file, "r") as f:
58 | contents = f.readlines()
59 | for line_str in contents:
60 | path_contents = [c for c in line_str.split('\t')]
61 | im_file_name = path_contents[0]
62 | class_index = int(path_contents[1])
63 |
64 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
65 | item = (os.path.join(img_prefix, im_file_name), class_index)
66 |
67 | images.append(item)
68 |
69 | return images
70 |
71 |
72 | class DatasetFolder(data.Dataset):
73 | """A generic data loader where the samples are arranged in this way: ::
74 | root/class_x/xxx.ext
75 | root/class_x/xxy.ext
76 | root/class_x/xxz.ext
77 | root/class_y/123.ext
78 | root/class_y/nsdf3.ext
79 | root/class_y/asd932_.ext
80 | Args:
81 | root (string): Root directory path.
82 | loader (callable): A function to load a sample given its path.
83 | extensions (list[string]): A list of allowed extensions.
84 | transform (callable, optional): A function/transform that takes in
85 | a sample and returns a transformed version.
86 | E.g, ``transforms.RandomCrop`` for images.
87 | target_transform (callable, optional): A function/transform that takes
88 | in the target and transforms it.
89 | Attributes:
90 | samples (list): List of (sample path, class_index) tuples
91 | """
92 |
93 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
94 | cache_mode="no"):
95 | # image folder mode
96 | if ann_file == '':
97 | _, class_to_idx = find_classes(root)
98 | samples = make_dataset(root, class_to_idx, extensions)
99 | # zip mode
100 | else:
101 | samples = make_dataset_with_ann(os.path.join(root, ann_file),
102 | os.path.join(root, img_prefix),
103 | extensions)
104 |
105 | if len(samples) == 0:
106 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
107 | "Supported extensions are: " + ",".join(extensions)))
108 |
109 | self.root = root
110 | self.loader = loader
111 | self.extensions = extensions
112 |
113 | self.samples = samples
114 | self.labels = [y_1k for _, y_1k in samples]
115 | self.classes = list(set(self.labels))
116 |
117 | self.transform = transform
118 | self.target_transform = target_transform
119 |
120 | self.cache_mode = cache_mode
121 | if self.cache_mode != "no":
122 | self.init_cache()
123 |
124 | def init_cache(self):
125 | assert self.cache_mode in ["part", "full"]
126 | n_sample = len(self.samples)
127 | global_rank = dist.get_rank()
128 | world_size = dist.get_world_size()
129 |
130 | samples_bytes = [None for _ in range(n_sample)]
131 | start_time = time.time()
132 | for index in range(n_sample):
133 |             if index % max(n_sample // 10, 1) == 0:  # avoid ZeroDivisionError when n_sample < 10
134 | t = time.time() - start_time
135 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
136 | start_time = time.time()
137 | path, target = self.samples[index]
138 | if self.cache_mode == "full":
139 | samples_bytes[index] = (ZipReader.read(path), target)
140 | elif self.cache_mode == "part" and index % world_size == global_rank:
141 | samples_bytes[index] = (ZipReader.read(path), target)
142 | else:
143 | samples_bytes[index] = (path, target)
144 | self.samples = samples_bytes
145 |
146 | def __getitem__(self, index):
147 | """
148 | Args:
149 | index (int): Index
150 | Returns:
151 | tuple: (sample, target) where target is class_index of the target class.
152 | """
153 | path, target = self.samples[index]
154 | sample = self.loader(path)
155 | if self.transform is not None:
156 | sample = self.transform(sample)
157 | if self.target_transform is not None:
158 | target = self.target_transform(target)
159 |
160 | return sample, target
161 |
162 | def __len__(self):
163 | return len(self.samples)
164 |
165 | def __repr__(self):
166 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
167 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
168 | fmt_str += ' Root Location: {}\n'.format(self.root)
169 | tmp = ' Transforms (if any): '
170 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
171 | tmp = ' Target Transforms (if any): '
172 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
173 | return fmt_str
174 |
175 |
176 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
177 |
178 |
179 | def pil_loader(path):
180 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
181 | if isinstance(path, bytes):
182 | img = Image.open(io.BytesIO(path))
183 | elif is_zip_path(path):
184 | data = ZipReader.read(path)
185 | img = Image.open(io.BytesIO(data))
186 | else:
187 | with open(path, 'rb') as f:
188 | img = Image.open(f)
189 | return img.convert('RGB')
190 |
191 |
192 | def accimage_loader(path):
193 | import accimage
194 | try:
195 | return accimage.Image(path)
196 | except IOError:
197 | # Potentially a decoding problem, fall back to PIL.Image
198 | return pil_loader(path)
199 |
200 |
201 | def default_img_loader(path):
202 | from torchvision import get_image_backend
203 | if get_image_backend() == 'accimage':
204 | return accimage_loader(path)
205 | else:
206 | return pil_loader(path)
207 |
208 |
209 | class CachedImageFolder(DatasetFolder):
210 | """A generic data loader where the images are arranged in this way: ::
211 | root/dog/xxx.png
212 | root/dog/xxy.png
213 | root/dog/xxz.png
214 | root/cat/123.png
215 | root/cat/nsdf3.png
216 | root/cat/asd932_.png
217 | Args:
218 | root (string): Root directory path.
219 | transform (callable, optional): A function/transform that takes in an PIL image
220 | and returns a transformed version. E.g, ``transforms.RandomCrop``
221 | target_transform (callable, optional): A function/transform that takes in the
222 | target and transforms it.
223 | loader (callable, optional): A function to load an image given its path.
224 | Attributes:
225 | imgs (list): List of (image path, class_index) tuples
226 | """
227 |
228 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
229 | loader=default_img_loader, cache_mode="no"):
230 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
231 | ann_file=ann_file, img_prefix=img_prefix,
232 | transform=transform, target_transform=target_transform,
233 | cache_mode=cache_mode)
234 | self.imgs = self.samples
235 |
236 | def __getitem__(self, index):
237 | """
238 | Args:
239 | index (int): Index
240 | Returns:
241 | tuple: (image, target) where target is class_index of the target class.
242 | """
243 | path, target = self.samples[index]
244 | image = self.loader(path)
245 | if self.transform is not None:
246 | img = self.transform(image)
247 | else:
248 | img = image
249 | if self.target_transform is not None:
250 | target = self.target_transform(target)
251 |
252 | return img, target
253 |
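Read together with `build_dataset` above, zip mode expects `DATA.DATA_PATH` to contain `train.zip`/`val.zip` plus `train_map.txt`/`val_map.txt`, where each map line is `<relative image path>\t<class index>` as parsed by `make_dataset_with_ann`. A hypothetical helper for producing such a map file (paths are placeholders):

```python
# Hypothetical: write a *_map.txt for ZIP_MODE in the tab-separated format
# that make_dataset_with_ann parses; class indices follow sorted folder names.
import zipfile

def write_map(zip_path: str, out_path: str) -> None:
    class_to_idx = {}
    with zipfile.ZipFile(zip_path) as zf, open(out_path, 'w') as out:
        for name in sorted(zf.namelist()):
            if not name.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            cls = name.strip('/').split('/')[0]            # top-level class folder
            idx = class_to_idx.setdefault(cls, len(class_to_idx))
            out.write(f"{name}\t{idx}\n")

# write_map('dataset/imagenet/train.zip', 'dataset/imagenet/train_map.txt')
```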
--------------------------------------------------------------------------------
/SPViT_Swin/data/samplers.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
8 |
9 | import torch
10 |
11 |
12 | class SubsetRandomSampler(torch.utils.data.Sampler):
13 | r"""Samples elements randomly from a given list of indices, without replacement.
14 |
15 | Arguments:
16 | indices (sequence): a sequence of indices
17 | """
18 |
19 | def __init__(self, indices):
20 | self.epoch = 0
21 | self.indices = indices
22 |
23 | def __iter__(self):
24 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
25 |
26 | def __len__(self):
27 | return len(self.indices)
28 |
29 | def set_epoch(self, epoch):
30 | self.epoch = epoch
31 |
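`SubsetRandomSampler` draws a fresh permutation of its fixed index list on every pass; a short usage sketch mirroring the rank-strided evaluation indices built in `build_loader` (sizes are illustrative):

```python
# Hypothetical usage: each rank samples a strided subset, as in build_loader.
import numpy as np
import torch
from data.samplers import SubsetRandomSampler

dataset = torch.utils.data.TensorDataset(torch.arange(100))
indices = np.arange(0, len(dataset), 4)  # e.g. rank 0 of a world size of 4
sampler = SubsetRandomSampler(indices)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8)
for (batch,) in loader:
    pass  # 25 strided samples per pass, visited in a fresh random order
```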
--------------------------------------------------------------------------------
/SPViT_Swin/data/zipreader.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
8 |
9 | import os
10 | import zipfile
11 | import io
12 | import numpy as np
13 | from PIL import Image
14 | from PIL import ImageFile
15 |
16 | ImageFile.LOAD_TRUNCATED_IMAGES = True
17 |
18 |
19 | def is_zip_path(img_or_path):
20 | """judge if this is a zip path"""
21 | return '.zip@' in img_or_path
22 |
23 |
24 | class ZipReader(object):
25 | """A class to read zipped files"""
26 | zip_bank = dict()
27 |
28 | def __init__(self):
29 | super(ZipReader, self).__init__()
30 |
31 | @staticmethod
32 | def get_zipfile(path):
33 | zip_bank = ZipReader.zip_bank
34 | if path not in zip_bank:
35 | zfile = zipfile.ZipFile(path, 'r')
36 | zip_bank[path] = zfile
37 | return zip_bank[path]
38 |
39 | @staticmethod
40 | def split_zip_style_path(path):
41 |         pos_at = path.find('@')  # find() returns -1 instead of raising, so the assert can fire
42 |         assert pos_at != -1, "character '@' is not found in the given path '%s'" % path
43 |
44 | zip_path = path[0: pos_at]
45 | folder_path = path[pos_at + 1:]
46 | folder_path = str.strip(folder_path, '/')
47 | return zip_path, folder_path
48 |
49 | @staticmethod
50 | def list_folder(path):
51 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
52 |
53 | zfile = ZipReader.get_zipfile(zip_path)
54 | folder_list = []
55 |         for file_folder_name in zfile.namelist():
56 |             file_folder_name = str.strip(file_folder_name, '/')
57 |             if file_folder_name.startswith(folder_path) and \
58 |                len(os.path.splitext(file_folder_name)[-1]) == 0 and \
59 |                file_folder_name != folder_path:
60 |                 if len(folder_path) == 0:
61 |                     folder_list.append(file_folder_name)
62 |                 else:
63 |                     folder_list.append(file_folder_name[len(folder_path) + 1:])
64 |
65 | return folder_list
66 |
67 | @staticmethod
68 | def list_files(path, extension=None):
69 | if extension is None:
70 | extension = ['.*']
71 | zip_path, folder_path = ZipReader.split_zip_style_path(path)
72 |
73 | zfile = ZipReader.get_zipfile(zip_path)
74 | file_lists = []
75 |         for file_folder_name in zfile.namelist():
76 |             file_folder_name = str.strip(file_folder_name, '/')
77 |             if file_folder_name.startswith(folder_path) and \
78 |                str.lower(os.path.splitext(file_folder_name)[-1]) in extension:
79 |                 if len(folder_path) == 0:
80 |                     file_lists.append(file_folder_name)
81 |                 else:
82 |                     file_lists.append(file_folder_name[len(folder_path) + 1:])
83 |
84 | return file_lists
85 |
86 | @staticmethod
87 | def read(path):
88 | zip_path, path_img = ZipReader.split_zip_style_path(path)
89 | zfile = ZipReader.get_zipfile(zip_path)
90 | data = zfile.read(path_img)
91 | return data
92 |
93 | @staticmethod
94 | def imread(path):
95 | zip_path, path_img = ZipReader.split_zip_style_path(path)
96 | zfile = ZipReader.get_zipfile(zip_path)
97 | data = zfile.read(path_img)
98 | try:
99 | im = Image.open(io.BytesIO(data))
100 |         except Exception:  # avoid a bare except so KeyboardInterrupt still propagates
101 | print("ERROR IMG LOADED: ", path_img)
102 | random_img = np.random.rand(224, 224, 3) * 255
103 | im = Image.fromarray(np.uint8(random_img))
104 | return im
105 |
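The `zip_path@inner_path` convention above lets one archive serve as a folder tree; a small sketch exercising the pure string helpers (the archive name is a placeholder, so nothing is read from disk):

```python
# Hypothetical usage of the zip-path convention handled by ZipReader.
from data.zipreader import is_zip_path, ZipReader

path = 'train.zip@/n01440764/example.JPEG'
assert is_zip_path(path)
zip_path, inner = ZipReader.split_zip_style_path(path)
assert (zip_path, inner) == ('train.zip', 'n01440764/example.JPEG')
# With a real archive on disk: img = ZipReader.imread(path)
```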
--------------------------------------------------------------------------------
/SPViT_Swin/dev/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Some scripts for developers to use, including:
3 |
4 | - `linter.sh`: lint the codebase before commit
5 | - `run_{inference,instant}_tests.sh`: run inference/training for a few iterations.
6 | Note that these tests require 2 GPUs.
7 | - `parse_results.sh`: parse results from a log file.
8 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/linter.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | # Run this script at project root by "./dev/linter.sh" before you commit
5 |
6 | vergte() {
7 | [ "$2" = "$(echo -e "$1\\n$2" | sort -V | head -n1)" ]
8 | }
9 |
10 | {
11 | black --version | grep -E "(19.3b0.*6733274)|(19.3b0\\+8)" > /dev/null
12 | } || {
13 | echo "Linter requires 'black @ git+https://github.com/psf/black@673327449f86fce558adde153bb6cbe54bfebad2' !"
14 | exit 1
15 | }
16 |
17 | ISORT_TARGET_VERSION="4.3.21"
18 | ISORT_VERSION=$(isort -v | grep VERSION | awk '{print $2}')
19 | vergte "$ISORT_VERSION" "$ISORT_TARGET_VERSION" || {
20 | echo "Linter requires isort>=${ISORT_TARGET_VERSION} !"
21 | exit 1
22 | }
23 |
24 | set -v
25 |
26 | echo "Running isort ..."
27 | isort -y -sp . --atomic
28 |
29 | echo "Running black ..."
30 | black -l 100 .
31 |
32 | echo "Running flake8 ..."
33 | if [ -x "$(command -v flake8-3)" ]; then
34 | flake8-3 .
35 | else
36 | python3 -m flake8 .
37 | fi
38 |
39 | # echo "Running mypy ..."
40 | # Pytorch does not have enough type annotations
41 | # mypy detectron2/solver detectron2/structures detectron2/config
42 |
43 | echo "Running clang-format ..."
44 | find . -regex ".*\.\(cpp\|c\|cc\|cu\|cxx\|h\|hh\|hpp\|hxx\|tcc\|mm\|m\)" -print0 | xargs -0 clang-format -i
45 |
46 | command -v arc > /dev/null && arc lint
47 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/packaging/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## To build a cu101 wheel for release:
3 |
4 | ```
5 | $ nvidia-docker run -it --storage-opt "size=20GB" --name pt pytorch/manylinux-cuda101
6 | # inside the container:
7 | # git clone https://github.com/facebookresearch/detectron2/
8 | # cd detectron2
9 | # export CU_VERSION=cu101 D2_VERSION_SUFFIX= PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4
10 | # ./dev/packaging/build_wheel.sh
11 | ```
12 |
13 | ## To build all wheels for `CUDA {9.2,10.0,10.1}` x `Python {3.6,3.7,3.8}`:
14 | ```
15 | ./dev/packaging/build_all_wheels.sh
16 | ./dev/packaging/gen_wheel_index.sh /path/to/wheels
17 | ```
18 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/packaging/build_all_wheels.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | [[ -d "dev/packaging" ]] || {
5 | echo "Please run this script at detectron2 root!"
6 | exit 1
7 | }
8 |
9 | build_one() {
10 | cu=$1
11 | pytorch_ver=$2
12 |
13 | case "$cu" in
14 | cu*)
15 | container_name=manylinux-cuda${cu/cu/}
16 | ;;
17 | cpu)
18 | container_name=manylinux-cuda101
19 | ;;
20 | *)
21 | echo "Unrecognized cu=$cu"
22 | exit 1
23 | ;;
24 | esac
25 |
26 | echo "Launching container $container_name ..."
27 |
28 | for py in 3.6 3.7 3.8; do
29 | docker run -itd \
30 | --name $container_name \
31 | --mount type=bind,source="$(pwd)",target=/detectron2 \
32 | pytorch/$container_name
33 |
34 |     cat <<EOF
[remainder of build_all_wheels.sh truncated in the source dump]
--------------------------------------------------------------------------------
/SPViT_Swin/dev/packaging/build_wheel.sh:
--------------------------------------------------------------------------------
7 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
8 | . "$script_dir/pkg_helpers.bash"
9 |
10 | echo "Build Settings:"
11 | echo "CU_VERSION: $CU_VERSION" # e.g. cu101
12 | echo "D2_VERSION_SUFFIX: $D2_VERSION_SUFFIX" # e.g. +cu101 or ""
13 | echo "PYTHON_VERSION: $PYTHON_VERSION" # e.g. 3.6
14 | echo "PYTORCH_VERSION: $PYTORCH_VERSION" # e.g. 1.4
15 |
16 | setup_cuda
17 | setup_wheel_python
18 | yum install ninja-build -y && ln -sv /usr/bin/ninja-build /usr/bin/ninja
19 |
20 | pip_install pip numpy -U
21 | pip_install "torch==$PYTORCH_VERSION" \
22 | -f https://download.pytorch.org/whl/"$CU_VERSION"/torch_stable.html
23 |
24 | # use separate directories to allow parallel build
25 | BASE_BUILD_DIR=build/cu$CU_VERSION-py$PYTHON_VERSION-pt$PYTORCH_VERSION
26 | python setup.py \
27 | build -b "$BASE_BUILD_DIR" \
28 | bdist_wheel -b "$BASE_BUILD_DIR/build_dist" -d "wheels/$CU_VERSION/torch$PYTORCH_VERSION"
29 | rm -rf "$BASE_BUILD_DIR"
30 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/packaging/gen_install_table.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 |
6 | template = """ install
\
7 | python -m pip install detectron2{d2_version} -f \\
8 | https://dl.fbaipublicfiles.com/detectron2/wheels/{cuda}/torch{torch}/index.html
9 |
"""
10 | CUDA_SUFFIX = {"10.2": "cu102", "10.1": "cu101", "10.0": "cu100", "9.2": "cu92", "cpu": "cpu"}
11 |
12 |
13 | def gen_header(torch_versions):
14 |     return '<table class="docutils"><tbody><th width="80"> CUDA </th>' + "".join(
15 |         [
16 |             '<th valign="bottom" align="left" width="100">torch {}</th>'.format(t)
17 |             for t in torch_versions
18 |         ]
19 |     )
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--d2-version", help="detectron2 version number, default to empty")
25 | args = parser.parse_args()
26 | d2_version = f"=={args.d2_version}" if args.d2_version else ""
27 |
28 | all_versions = [("1.4", k) for k in ["10.1", "10.0", "9.2", "cpu"]] + [
29 | ("1.5", k) for k in ["10.2", "10.1", "9.2", "cpu"]
30 | ]
31 |
32 | torch_versions = sorted({k[0] for k in all_versions}, key=float, reverse=True)
33 | cuda_versions = sorted(
34 | {k[1] for k in all_versions}, key=lambda x: float(x) if x != "cpu" else 0, reverse=True
35 | )
36 |
37 | table = gen_header(torch_versions)
38 | for cu in cuda_versions:
39 | table += f""" {cu} | """
40 | cu_suffix = CUDA_SUFFIX[cu]
41 | for torch in torch_versions:
42 | if (torch, cu) in all_versions:
43 | cell = template.format(d2_version=d2_version, cuda=cu_suffix, torch=torch)
44 | else:
45 | cell = ""
46 | table += f"""{cell} | """
47 | table += "
"
48 | table += "
"
49 | print(table)
50 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/packaging/gen_wheel_index.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 |
5 | root=$1
6 | if [[ -z "$root" ]]; then
7 | echo "Usage: ./gen_wheel_index.sh /path/to/wheels"
8 | exit
9 | fi
10 |
11 | export LC_ALL=C # reproducible sort
12 | # NOTE: all sort in this script might not work when xx.10 is released
13 |
14 | index=$root/index.html
15 |
16 | cd "$root"
17 | for cu in cpu cu92 cu100 cu101 cu102; do
18 | cd "$root/$cu"
19 | echo "Creating $PWD/index.html ..."
20 | # First sort by torch version, then stable sort by d2 version with unique.
21 | # As a result, the latest torch version for each d2 version is kept.
22 | for whl in $(find -type f -name '*.whl' -printf '%P\n' \
23 | | sort -k 1 -r | sort -t '/' -k 2 --stable -r --unique); do
24 | echo "$whl
"
25 | done > index.html
26 |
27 |
28 | for torch in torch*; do
29 | cd "$root/$cu/$torch"
30 |
31 | # list all whl for each cuda,torch version
32 | echo "Creating $PWD/index.html ..."
33 | for whl in $(find . -type f -name '*.whl' -printf '%P\n' | sort -r); do
34 | echo "$whl
"
35 | done > index.html
36 | done
37 | done
38 |
39 | cd "$root"
40 | # Just list everything:
41 | echo "Creating $index ..."
42 | for whl in $(find . -type f -name '*.whl' -printf '%P\n' | sort -r); do
43 | echo "$whl
"
44 | done > "$index"
45 |
46 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/packaging/pkg_helpers.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | # Function to retry functions that sometimes timeout or have flaky failures
5 | retry () {
6 | $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
7 | }
8 | # Install with pip a bit more robustly than the default
9 | pip_install() {
10 | retry pip install --progress-bar off "$@"
11 | }
12 |
13 |
14 | setup_cuda() {
15 | # Now work out the CUDA settings
16 | # Like other torch domain libraries, we choose common GPU architectures only.
17 | export FORCE_CUDA=1
18 | case "$CU_VERSION" in
19 | cu102)
20 | export CUDA_HOME=/usr/local/cuda-10.2/
21 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX;7.5+PTX"
22 | ;;
23 | cu101)
24 | export CUDA_HOME=/usr/local/cuda-10.1/
25 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX;7.5+PTX"
26 | ;;
27 | cu100)
28 | export CUDA_HOME=/usr/local/cuda-10.0/
29 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX;7.5+PTX"
30 | ;;
31 | cu92)
32 | export CUDA_HOME=/usr/local/cuda-9.2/
33 | export TORCH_CUDA_ARCH_LIST="3.5;3.7;5.0;5.2;6.0+PTX;6.1+PTX;7.0+PTX"
34 | ;;
35 | cpu)
36 | unset FORCE_CUDA
37 | export CUDA_VISIBLE_DEVICES=
38 | ;;
39 | *)
40 | echo "Unrecognized CU_VERSION=$CU_VERSION"
41 | exit 1
42 | ;;
43 | esac
44 | }
45 |
46 | setup_wheel_python() {
47 | case "$PYTHON_VERSION" in
48 | 3.6) python_abi=cp36-cp36m ;;
49 | 3.7) python_abi=cp37-cp37m ;;
50 | 3.8) python_abi=cp38-cp38 ;;
51 | *)
52 | echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION"
53 | exit 1
54 | ;;
55 | esac
56 | export PATH="/opt/python/$python_abi/bin:$PATH"
57 | }
58 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/parse_results.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | # A shell script that parses metrics from the log file.
5 | # Make it easier for developers to track performance of models.
6 |
7 | LOG="$1"
8 |
9 | if [[ -z "$LOG" ]]; then
10 | echo "Usage: $0 /path/to/log/file"
11 | exit 1
12 | fi
13 |
14 | # [12/15 11:47:32] trainer INFO: Total training time: 12:15:04.446477 (0.4900 s / it)
15 | # [12/15 11:49:03] inference INFO: Total inference time: 0:01:25.326167 (0.13652186737060548 s / img per device, on 8 devices)
16 | # [12/15 11:49:03] inference INFO: Total inference pure compute time: .....
17 |
18 | # training time
19 | trainspeed=$(grep -o 'Overall training.*' "$LOG" | grep -Eo '\(.*\)' | grep -o '[0-9\.]*')
20 | echo "Training speed: $trainspeed s/it"
21 |
22 | # inference time: there could be multiple inference during training
23 | inferencespeed=$(grep -o 'Total inference pure.*' "$LOG" | tail -n1 | grep -Eo '\(.*\)' | grep -o '[0-9\.]*' | head -n1)
24 | echo "Inference speed: $inferencespeed s/it"
25 |
26 | # [12/15 11:47:18] trainer INFO: eta: 0:00:00 iter: 90000 loss: 0.5407 (0.7256) loss_classifier: 0.1744 (0.2446) loss_box_reg: 0.0838 (0.1160) loss_mask: 0.2159 (0.2722) loss_objectness: 0.0244 (0.0429) loss_rpn_box_reg: 0.0279 (0.0500) time: 0.4487 (0.4899) data: 0.0076 (0.0975) lr: 0.000200 max mem: 4161
27 | memory=$(grep -o 'max[_ ]mem: [0-9]*' "$LOG" | tail -n1 | grep -o '[0-9]*')
28 | echo "Training memory: $memory MB"
29 |
30 | echo "Easy to copypaste:"
31 | echo "$trainspeed","$inferencespeed","$memory"
32 |
33 | echo "------------------------------"
34 |
35 | # [12/26 17:26:32] engine.coco_evaluation: copypaste: Task: bbox
36 | # [12/26 17:26:32] engine.coco_evaluation: copypaste: AP,AP50,AP75,APs,APm,APl
37 | # [12/26 17:26:32] engine.coco_evaluation: copypaste: 0.0017,0.0024,0.0017,0.0005,0.0019,0.0011
38 | # [12/26 17:26:32] engine.coco_evaluation: copypaste: Task: segm
39 | # [12/26 17:26:32] engine.coco_evaluation: copypaste: AP,AP50,AP75,APs,APm,APl
40 | # [12/26 17:26:32] engine.coco_evaluation: copypaste: 0.0014,0.0021,0.0016,0.0005,0.0016,0.0011
41 |
42 | echo "COCO Results:"
43 | num_tasks=$(grep -o 'copypaste:.*Task.*' "$LOG" | sort -u | wc -l)
44 | # each task has 3 lines
45 | grep -o 'copypaste:.*' "$LOG" | cut -d ' ' -f 2- | tail -n $((num_tasks * 3))
46 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/run_inference_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | BIN="python tools/train_net.py"
5 | OUTPUT="inference_test_output"
6 | NUM_GPUS=2
7 |
8 | CFG_LIST=( "${@:1}" )
9 |
10 | if [ ${#CFG_LIST[@]} -eq 0 ]; then
11 | CFG_LIST=( ./configs/quick_schedules/*inference_acc_test.yaml )
12 | fi
13 |
14 | echo "========================================================================"
15 | echo "Configs to run:"
16 | echo "${CFG_LIST[@]}"
17 | echo "========================================================================"
18 |
19 |
20 | for cfg in "${CFG_LIST[@]}"; do
21 | echo "========================================================================"
22 | echo "Running $cfg ..."
23 | echo "========================================================================"
24 | $BIN \
25 | --eval-only \
26 | --num-gpus $NUM_GPUS \
27 | --config-file "$cfg" \
28 | OUTPUT_DIR $OUTPUT
29 | rm -rf $OUTPUT
30 | done
31 |
32 |
33 | echo "========================================================================"
34 | echo "Running demo.py ..."
35 | echo "========================================================================"
36 | DEMO_BIN="python demo/demo.py"
37 | COCO_DIR=datasets/coco/val2014
38 | mkdir -pv $OUTPUT
39 |
40 | set -v
41 |
42 | $DEMO_BIN --config-file ./configs/quick_schedules/panoptic_fpn_R_50_inference_acc_test.yaml \
43 | --input $COCO_DIR/COCO_val2014_0000001933* --output $OUTPUT
44 | rm -rf $OUTPUT
45 |
--------------------------------------------------------------------------------
/SPViT_Swin/dev/run_instant_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 |
4 | BIN="python tools/train_net.py"
5 | OUTPUT="instant_test_output"
6 | NUM_GPUS=2
7 |
8 | CFG_LIST=( "${@:1}" )
9 | if [ ${#CFG_LIST[@]} -eq 0 ]; then
10 | CFG_LIST=( ./configs/quick_schedules/*instant_test.yaml )
11 | fi
12 |
13 | echo "========================================================================"
14 | echo "Configs to run:"
15 | echo "${CFG_LIST[@]}"
16 | echo "========================================================================"
17 |
18 | for cfg in "${CFG_LIST[@]}"; do
19 | echo "========================================================================"
20 | echo "Running $cfg ..."
21 | echo "========================================================================"
22 | $BIN --num-gpus $NUM_GPUS --config-file "$cfg" \
23 | SOLVER.IMS_PER_BATCH $(($NUM_GPUS * 2)) \
24 | OUTPUT_DIR "$OUTPUT"
25 | rm -rf "$OUTPUT"
26 | done
27 |
28 |
--------------------------------------------------------------------------------
/SPViT_Swin/ffn_indicators/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_Swin/ffn_indicators/.DS_Store
--------------------------------------------------------------------------------
/SPViT_Swin/ffn_indicators/spvit_swin_bs_l01_t100_search_20epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_Swin/ffn_indicators/spvit_swin_bs_l01_t100_search_20epoch.pth
--------------------------------------------------------------------------------
/SPViT_Swin/ffn_indicators/spvit_swin_sm_l04_t55_search_14epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_Swin/ffn_indicators/spvit_swin_sm_l04_t55_search_14epoch.pth
--------------------------------------------------------------------------------
/SPViT_Swin/ffn_indicators/spvit_swin_t_l28_t32_search_12epoch.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ziplab/SPViT/ae19b0dafae3baf15c8eb4e817f4156386936866/SPViT_Swin/ffn_indicators/spvit_swin_t_l28_t32_search_12epoch.pth
--------------------------------------------------------------------------------
/SPViT_Swin/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
8 |
9 | import os
10 | import sys
11 | import logging
12 | import functools
13 | from termcolor import colored
14 |
15 |
16 | @functools.lru_cache()
17 | def create_logger(output_dir, dist_rank=0, name=''):
18 | # create logger
19 | logger = logging.getLogger(name)
20 | logger.setLevel(logging.DEBUG)
21 | logger.propagate = False
22 |
23 | # create formatter
24 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
25 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
26 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
27 |
28 | # create console handlers for master process
29 | if dist_rank == 0:
30 | console_handler = logging.StreamHandler(sys.stdout)
31 | console_handler.setLevel(logging.DEBUG)
32 | console_handler.setFormatter(
33 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
34 | logger.addHandler(console_handler)
35 |
36 | # create file handlers
37 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
38 | file_handler.setLevel(logging.DEBUG)
39 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
40 | logger.addHandler(file_handler)
41 |
42 | return logger
43 |
--------------------------------------------------------------------------------
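A minimal usage sketch for create_logger follows; the directory and logger name are illustrative stand-ins (the real entry points pass the configured output path and the process's distributed rank). Because the factory is wrapped in functools.lru_cache, calling it again with the same arguments returns the already-configured logger rather than stacking duplicate handlers.

import os
from logger import create_logger

output_dir = 'logger_demo_output'  # hypothetical directory for this sketch
os.makedirs(output_dir, exist_ok=True)

# Rank 0 attaches a colored console handler in addition to the file
# handler; every rank writes its own log_rank{dist_rank}.txt.
logger = create_logger(output_dir, dist_rank=0, name='spvit_demo')
logger.info('messages now go to stdout and log_rank0.txt')
--------------------------------------------------------------------------------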
/SPViT_Swin/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 | # Modifications copyright (c) 2021 Zhuang AI Group, Haoyu He
8 |
9 | import torch
10 | from timm.scheduler.cosine_lr import CosineLRScheduler
11 | from timm.scheduler.step_lr import StepLRScheduler
12 | from timm.scheduler.scheduler import Scheduler
13 |
14 |
15 | def build_scheduler(config, optimizer, n_iter_per_epoch, num_steps=None, warmup_steps=None, min_lr=None):
16 | if num_steps is None:
17 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
18 | if warmup_steps is None:
19 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
20 | if min_lr is None:
21 | min_lr = config.TRAIN.MIN_LR
22 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
23 |
24 | lr_scheduler = None
25 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
26 | lr_scheduler = CosineLRScheduler(
27 | optimizer,
28 | t_initial=num_steps,
29 | t_mul=1.,
30 | lr_min=min_lr,
31 | warmup_lr_init=config.TRAIN.WARMUP_LR,
32 | warmup_t=warmup_steps,
33 | cycle_limit=1,
34 | t_in_epochs=False,
35 | )
36 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
37 | lr_scheduler = LinearLRScheduler(
38 | optimizer,
39 | t_initial=num_steps,
40 | lr_min_rate=0.01,
41 | warmup_lr_init=config.TRAIN.WARMUP_LR,
42 | warmup_t=warmup_steps,
43 | t_in_epochs=False,
44 | )
45 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
46 | lr_scheduler = StepLRScheduler(
47 | optimizer,
48 | decay_t=decay_steps,
49 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
50 | warmup_lr_init=config.TRAIN.WARMUP_LR,
51 | warmup_t=warmup_steps,
52 | t_in_epochs=False,
53 | )
54 |
55 | return lr_scheduler
56 |
57 |
58 | class LinearLRScheduler(Scheduler):
59 | def __init__(self,
60 | optimizer: torch.optim.Optimizer,
61 | t_initial: int,
62 | lr_min_rate: float,
63 | warmup_t=0,
64 | warmup_lr_init=0.,
65 | t_in_epochs=True,
66 | noise_range_t=None,
67 | noise_pct=0.67,
68 | noise_std=1.0,
69 | noise_seed=42,
70 | initialize=True,
71 | ) -> None:
72 | super().__init__(
73 | optimizer, param_group_field="lr",
74 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
75 | initialize=initialize)
76 |
77 | self.t_initial = t_initial
78 | self.lr_min_rate = lr_min_rate
79 | self.warmup_t = warmup_t
80 | self.warmup_lr_init = warmup_lr_init
81 | self.t_in_epochs = t_in_epochs
82 | if self.warmup_t:
83 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
84 | super().update_groups(self.warmup_lr_init)
85 | else:
86 | self.warmup_steps = [1 for _ in self.base_values]
87 |
88 | def _get_lr(self, t):
89 | if t < self.warmup_t:
90 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
91 | else:
92 | t = t - self.warmup_t
93 | total_t = self.t_initial - self.warmup_t
94 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
95 | return lrs
96 |
97 | def get_epoch_values(self, epoch: int):
98 | if self.t_in_epochs:
99 | return self._get_lr(epoch)
100 | else:
101 | return None
102 |
103 | def get_update_values(self, num_updates: int):
104 | if not self.t_in_epochs:
105 | return self._get_lr(num_updates)
106 | else:
107 | return None
108 |
--------------------------------------------------------------------------------
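Two points in the scheduler above are worth spelling out. First, past warmup the linear branch interpolates each base learning rate v down to v * lr_min_rate as (t - warmup_t) / (t_initial - warmup_t) goes from 0 to 1. Second, every branch is built with t_in_epochs=False, so the scheduler is advanced once per iteration through timm's step_update(num_updates) rather than once per epoch. The sketch below wires this together with an illustrative config stub; the hyper-parameter values are invented for the demo, not taken from the SPViT configs, and the timm version pinned in requirements.txt is assumed.

from types import SimpleNamespace

import torch

from lr_scheduler import build_scheduler

# Stand-in for the yacs config: only the fields build_scheduler reads.
config = SimpleNamespace(TRAIN=SimpleNamespace(
    EPOCHS=2, WARMUP_EPOCHS=1, MIN_LR=1e-5, WARMUP_LR=1e-6,
    LR_SCHEDULER=SimpleNamespace(NAME='linear', DECAY_EPOCHS=1, DECAY_RATE=0.1),
))

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

n_iter_per_epoch = 5
lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch)

for epoch in range(config.TRAIN.EPOCHS):
    for it in range(n_iter_per_epoch):
        optimizer.step()
        # t_in_epochs=False: pass the global iteration count, mirroring
        # how the training loops drive the scheduler once per batch.
        lr_scheduler.step_update(epoch * n_iter_per_epoch + it)
        print(optimizer.param_groups[0]['lr'])
--------------------------------------------------------------------------------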
/SPViT_Swin/main.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import time
10 | import argparse
11 | import datetime
12 | import numpy as np
13 |
14 | import torch
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 |
18 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
19 | from timm.utils import accuracy, AverageMeter
20 |
21 | from config import get_config
22 | from models import build_model
23 | from data import build_loader
24 | from lr_scheduler import build_scheduler
25 | from optimizer import build_optimizer
26 | from logger import create_logger
27 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
28 |
29 | try:
30 | # noinspection PyUnresolvedReferences
31 | from apex import amp
32 | except ImportError:
33 | amp = None
34 |
35 |
36 | def parse_option():
37 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
38 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
39 | parser.add_argument(
40 | "--opts",
41 | help="Modify config options by adding 'KEY VALUE' pairs. ",
42 | default=None,
43 | nargs='+',
44 | )
45 |
46 | # easy config modification
47 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
48 | parser.add_argument('--data-path', type=str, help='path to dataset')
49 | parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
50 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
51 | help='no: no cache, '
52 | 'full: cache all data, '
53 | 'part: shard the dataset into non-overlapping pieces and cache only one piece')
54 | parser.add_argument('--resume', help='resume from checkpoint')
55 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
56 | parser.add_argument('--use-checkpoint', action='store_true',
57 | help="whether to use gradient checkpointing to save memory")
58 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
59 | help='mixed precision opt level, if O0, no amp is used')
60 | parser.add_argument('--output', default='output', type=str, metavar='PATH',
61 | help='root of output folder, the full path is