├── .idea
│   ├── .gitignore
│   ├── PreNAS.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── vcs.xml
├── 01_zero_shot_search.sh
├── 02_one_shot_training.sh
├── 03_evaluation.sh
├── LICENSE.txt
├── README.md
├── candidates_to_choices.py
├── evolution_pre_train.py
├── experiments
│   └── supernet
│       ├── base.yaml
│       ├── small.yaml
│       ├── supernet-B.yaml
│       ├── supernet-S.yaml
│       ├── supernet-T.yaml
│       └── tiny.yaml
├── figure
│   └── overview.svg
├── interval_cands
│   ├── base.json
│   ├── small.json
│   └── tiny.json
├── lib
│   ├── config.py
│   ├── cuda.py
│   ├── datasets.py
│   ├── imagenet_withhold.py
│   ├── samplers.py
│   ├── score_maker.py
│   ├── subImageNet.py
│   └── utils.py
├── model
│   ├── module
│   │   ├── Linear_super.py
│   │   ├── __init__.py
│   │   ├── embedding_super.py
│   │   ├── layernorm_super.py
│   │   ├── multihead_super.py
│   │   ├── qkv_super.py
│   │   └── scaling_super.py
│   ├── supernet_transformer.py
│   └── utils.py
├── requirements.txt
├── supernet_engine.py
├── supernet_train.py
└── two_step_search.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/01_zero_shot_search.sh:
--------------------------------------------------------------------------------
1 | ### for tiny search space
2 | python -m torch.distributed.launch \
3 | --nproc_per_node=8 \
4 | --use_env \
5 | two_step_search.py \
6 | --gp \
7 | --change_qk \
8 | --relative_position \
9 | --dist-eval \
10 | --batch-size 64 \
11 | --data-free \
12 | --score-method left_super_taylor6 \
13 | --block-score-method-for-head balance_taylor6_max_dim \
14 | --block-score-method-for-mlp balance_taylor6_max_dim \
15 | --cand-per-interval 1 \
16 | --param-interval 1.0 \
17 | --min_param_limits 5 \
18 | --param_limits 12 \
19 | --data-path ../datas/imagenet \
20 | --cfg ./experiments/supernet/supernet-T.yaml \
21 | --interval-cands-output ./interval_cands/tiny.json
22 |
23 | python candidates_to_choices.py ./interval_cands/tiny.json ./experiments/supernet/tiny.yaml
24 |
25 | ### for small search space
26 | #python -m torch.distributed.launch \
27 | #--nproc_per_node=8 \
28 | #--use_env \
29 | #two_step_search.py \
30 | #--gp \
31 | #--change_qk \
32 | #--relative_position \
33 | #--dist-eval \
34 | #--batch-size 64 \
35 | #--data-free \
36 | #--score-method left_super_taylor6 \
37 | #--block-score-method-for-head balance_taylor6_max_dim \
38 | #--block-score-method-for-mlp balance_taylor6_max_dim \
39 | #--cand-per-interval 1 \
40 | #--param-interval 5.0 \
41 | #--min_param_limits 13 \
42 | #--param_limits 33 \
43 | #--data-path ../datas/imagenet \
44 | #--cfg ./experiments/supernet/supernet-S.yaml \
45 | #--interval-cands-output ./interval_cands/small.json
46 | #
47 | #python candidates_to_choices.py ./interval_cands/small.json ./experiments/supernet/small.yaml
48 |
49 | ### for base search space
50 | #python -m torch.distributed.launch \
51 | #--nproc_per_node=8 \
52 | #--use_env two_step_search.py \
53 | #--gp \
54 | #--change_qk \
55 | #--relative_position \
56 | #--dist-eval \
57 | #--batch-size 64 \
58 | #--data-free \
59 | #--score-method left_super_taylor6 \
60 | #--block-score-method-for-head balance_taylor6_max_dim \
61 | #--block-score-method-for-mlp balance_taylor6_max_dim \
62 | #--cand-per-interval 1 \
63 | #--param-interval 12.0 \
64 | #--min_param_limits 30 \
65 | #--param_limits 70 \
66 | #--data-path ../datas/imagenet \
67 | #--cfg ./experiments/supernet/supernet-B.yaml \
68 | #--interval-cands-output ./interval_cands/base.json
69 | #
70 | #python candidates_to_choices.py ./interval_cands/base.json ./experiments/supernet/base.yaml
--------------------------------------------------------------------------------
/02_one_shot_training.sh:
--------------------------------------------------------------------------------
1 | ### train PreNAS_tiny
2 | python -m torch.distributed.launch \
3 | --nproc_per_node=8 \
4 | --use_env \
5 | supernet_train.py \
6 | --gp \
7 | --change_qk \
8 | --relative_position \
9 | --mode super \
10 | --dist-eval \
11 | --epochs 500 \
12 | --warmup-epochs 20 \
13 | --batch-size 128 \
14 | --min-lr 1e-7 \
15 | --group-by-dim \
16 | --group-by-depth \
17 | --mixup-mode elem \
18 | --aa rand-n3-m10-mstd0.5-inc1 \
19 | --recount 2 \
20 | --data-path ../datas/imagenet \
21 | --cfg ./experiments/supernet/tiny.yaml \
22 | --candfile ./interval_cands/tiny.json \
23 | --output_dir ./output/tiny
24 |
25 | ### train PreNAS_small
26 | #python -m torch.distributed.launch \
27 | #--nproc_per_node=8 \
28 | #--use_env \
29 | #supernet_train.py \
30 | #--gp \
31 | #--change_qk \
32 | #--relative_position \
33 | #--mode super \
34 | #--dist-eval \
35 | #--epochs 500 \
36 | #--warmup-epochs 20 \
37 | #--batch-size 128 \
38 | #--group-by-dim \
39 | #--group-by-depth \
40 | #--mixup-mode elem \
41 | #--aa v0r-mstd0.5 \
42 | #--data-path ../datas/imagenet \
43 | #--cfg ./experiments/supernet/small.yaml \
44 | #--candfile ./interval_cands/small.json \
45 | #--output_dir ./output/small
46 |
47 | ### train PreNAS_base
48 | #python -m torch.distributed.launch \
49 | #--nproc_per_node=8 \
50 | #--use_env \
51 | #supernet_train.py \
52 | #--gp \
53 | #--change_qk \
54 | #--relative_position \
55 | #--mode super \
56 | #--dist-eval \
57 | #--epochs 500 \
58 | #--warmup-epochs 20 \
59 | #--batch-size 128 \
60 | #--min-lr 1e-7 \
61 | #--group-by-dim \
62 | #--group-by-depth \
63 | #--mixup-mode elem \
64 | #--aa rand-n3-m10-mstd0.5-inc1 \
65 | #--recount 2 \
66 | #--data-path ../datas/imagenet \
67 | #--cfg ./experiments/supernet/base.yaml \
68 | #--candfile ./interval_cands/base.json \
69 | #--output_dir ./output/base
70 |
--------------------------------------------------------------------------------
/03_evaluation.sh:
--------------------------------------------------------------------------------
1 | ### eval PreNAS_tiny
2 | python -m torch.distributed.launch \
3 | --nproc_per_node=8 \
4 | --use_env \
5 | supernet_train.py \
6 | --gp \
7 | --change_qk \
8 | --relative_position \
9 | --mode retrain \
10 | --dist-eval \
11 | --batch-size 128 \
12 | --eval \
13 | --data-path ../datas/imagenet \
14 | --cfg ./experiments/supernet/tiny.yaml \
15 | --candfile ./interval_cands/tiny.json \
16 | --resume ./output/tiny/checkpoint.pth
17 |
18 | ### eval PreNAS_small
19 | #python -m torch.distributed.launch \
20 | #--nproc_per_node=8 \
21 | #--use_env \
22 | #supernet_train.py \
23 | #--gp \
24 | #--change_qk \
25 | #--relative_position \
26 | #--mode retrain \
27 | #--dist-eval \
28 | #--batch-size 128 \
29 | #--eval \
30 | #--data-path ../datas/imagenet \
31 | #--cfg ./experiments/supernet/small.yaml \
32 | #--candfile ./interval_cands/small.json \
33 | #--resume ./output/small/checkpoint.pth
34 |
35 | ### eval PreNAS_base
36 | #python -m torch.distributed.launch \
37 | #--nproc_per_node=8 \
38 | #--use_env \
39 | #supernet_train.py \
40 | #--gp \
41 | #--change_qk \
42 | #--relative_position \
43 | #--mode retrain \
44 | #--dist-eval \
45 | #--batch-size 128 \
46 | #--eval \
47 | #--data-path ../datas/imagenet \
48 | #--cfg ./experiments/supernet/base.yaml \
49 | #--candfile ./interval_cands/base.json \
50 | #--resume ./output/base/checkpoint.pth
51 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022-2023 Alibaba Group Holding Limited.
190 |
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PreNAS: Preferred One-Shot Learning Towards Efficient Neural Architecture Search
2 |
3 | PreNAS is a novel learning paradigm that integrates one-shot and zero-shot NAS techniques to enhance search efficiency and training effectiveness.
4 | This search-free approach outperforms current state-of-the-art one-shot NAS methods
5 | for both Vision Transformer and convolutional architectures.
6 |
7 | > Wang H, Ge C, Chen H, and Sun X. PreNAS: Preferred One-Shot Learning Towards Efficient Neural Architecture Search. ICML 2023.
8 |
9 | Paper link: [arXiv](https://arxiv.org/abs/2304.14636)
10 |
11 | ## Overview
12 |
13 | ![PreNAS overview](figure/overview.svg)
14 |
15 | Previous one-shot NAS methods sample every architecture in the search space during one-shot supernet training, so that any architecture can later be evaluated in the evolution search.
16 | Instead, PreNAS first identifies the target architectures via a zero-cost proxy and then applies preferred one-shot training to the supernet.
17 | PreNAS improves the Pareto frontier thanks to the preferred one-shot learning, and it is search-free after training:
18 | the models with the architectures pre-selected by the zero-cost search can be served directly.
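
In code terms, the difference is where each training step's architecture comes from. Below is a minimal sketch of "preferred" sampling — the `sample_preferred` helper is illustrative only, not this repository's API (the real training path is `supernet_train.py --candfile ...`) — reading the format of `interval_cands/*.json`:

```python
import json
import random

def sample_preferred(cand_file: str) -> dict:
    """Pick one zero-cost-selected architecture for the next training step."""
    with open(cand_file) as f:
        interval_cands = json.load(f)  # {param budget: [architecture configs]}
    cands = [c for cand_list in interval_cands.values() for c in cand_list]
    return random.choice(cands)  # instead of sampling the full search space

if __name__ == '__main__':
    cfg = sample_preferred('./interval_cands/tiny.json')
    print(cfg['layer_num'], cfg['embed_dim'][0], cfg['num_params'])
```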
19 |
20 | ## Environment Setup
21 |
22 | To set up the environment, run the following commands:
23 | ```bash
24 | conda create -n PreNAS python=3.7
25 | conda activate PreNAS
26 | pip install -r requirements.txt
27 | ```
28 |
29 | ## Data Preparation
30 | You need to download [ImageNet-2012](http://www.image-net.org/) to the folder `../datas/imagenet` (the path the provided scripts point to).
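
The scripts load the dataset from that relative path; the usual ImageFolder split is assumed here (standard DeiT-style layout — adjust if your copy is organized differently):

```
../datas/imagenet/
├── train/   # one sub-directory per class
└── val/     # one sub-directory per class
```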
31 |
32 | ## Run example
33 | The code was run on 8 x 80GB A100 GPUs.
34 | - Zero-Shot Search
35 |
36 | `bash 01_zero_shot_search.sh`
37 |
38 | - One-Shot Training
39 |
40 | `bash 02_one_shot_training.sh`
41 |
42 | - Evaluation
43 |
44 | `bash 03_evaluation.sh`
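
The three steps run in order: `01_zero_shot_search.sh` writes the selected candidates to `./interval_cands/tiny.json` and derives `./experiments/supernet/tiny.yaml` from them, `02_one_shot_training.sh` trains the supernet on exactly those candidates and saves `./output/tiny/checkpoint.pth`, and `03_evaluation.sh` evaluates the pre-selected architectures from that checkpoint via `--resume`. The commented-out blocks in each script are the corresponding commands for the small and base search spaces.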
45 |
46 | ## Model Zoo
47 |
48 | | Model | TOP-1 (%) | TOP-5 (%) | #Params (M) | FLOPs (G) | Download Link |
49 | | ------------ | ---------- | ------------- | ------------- | --------- | ------------- |
50 | | PreNAS-Ti | 77.1 | 93.4 | 5.9 | 1.4 | [AliCloud](https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/PreNAS/supernet-tiny.pth) |
51 | | PreNAS-S | 81.8 | 95.9 | 22.9 | 5.1 | [AliCloud](https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/PreNAS/supernet-small.pth) |
52 | | PreNAS-B | 82.6 | 96.0 | 54 | 11 | [AliCloud](https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/PreNAS/supernet-base.pth) |
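
To sanity-check a downloaded checkpoint before passing it to `03_evaluation.sh` via `--resume`, a quick inspection sketch can help (whether the released `.pth` files wrap the weights as `{'model': state_dict}` like the training checkpoints is an assumption, so the code falls back to a bare state dict):

```python
import torch

ckpt = torch.load('supernet-tiny.pth', map_location='cpu')
# training checkpoints keep the weights under 'model'; fall back defensively
state_dict = ckpt.get('model', ckpt) if isinstance(ckpt, dict) else ckpt
print(f'{len(state_dict)} parameter tensors loaded')
```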
53 |
54 | ## Bibtex
55 |
56 | If PreNAS is useful for you, please consider citing it. Thank you! :)
57 | ```bibtex
58 | @InProceedings{PreNAS,
59 | title = {PreNAS: Preferred One-Shot Learning Towards Efficient Neural Architecture Search},
60 | author = {Wang, Haibin and Ge, Ce and Chen, Hesen and Sun, Xiuyu},
61 | booktitle = {International Conference on Machine Learning (ICML)},
62 | month = {July},
63 | year = {2023}
64 | }
65 | ```
66 |
67 | ## Acknowledgements
68 |
69 | The code is inspired by [AutoFormer](https://github.com/microsoft/Cream/tree/main/AutoFormer).
70 |
--------------------------------------------------------------------------------
/candidates_to_choices.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import json
3 | from collections import defaultdict
4 |
5 | def candidate_to_choices(candidate_path, topN=float('inf')):
6 |     with open(candidate_path) as f:
7 |         interval_cands = json.load(f)
8 |     # collect the per-embed-dim choices found in the top-N candidates of each interval
9 |     new_embed_dim = []
10 |     new_mlp_ratio = defaultdict(lambda: defaultdict(list))
11 |     new_num_heads = defaultdict(lambda: defaultdict(list))
12 |     new_depth = defaultdict(list)
13 |
14 |     for cand_list in interval_cands.values():
15 |         for i in range(min(topN, len(cand_list))):
16 |             cur_cand = cand_list[i]
17 |             # embed dim
18 |             embed_dim = cur_cand['embed_dim'][0]
19 |             if embed_dim not in new_embed_dim: new_embed_dim.append(embed_dim)
20 |             # depth
21 |             depth = cur_cand['layer_num']
22 |             if depth not in new_depth[embed_dim]: new_depth[embed_dim].append(depth)
23 |             # per-layer mlp ratios & head counts
24 |             for layer_id, (mlp_ratio, num_heads) in enumerate(zip(cur_cand['mlp_ratio'], cur_cand['num_heads'])):
25 |                 pt_mlp_ratio = new_mlp_ratio[embed_dim][layer_id]
26 |                 if mlp_ratio not in pt_mlp_ratio: pt_mlp_ratio.append(mlp_ratio)
27 |                 pt_num_heads = new_num_heads[embed_dim][layer_id]
28 |                 if num_heads not in pt_num_heads: pt_num_heads.append(num_heads)
29 |
30 | return {'embed_dim': sorted(new_embed_dim),
31 | 'mlp_ratio': {dim: [sorted(ratios[layer]) for layer in sorted(ratios)] for dim, ratios in new_mlp_ratio.items()},
32 | 'num_heads': {dim: [sorted(heads[layer]) for layer in sorted(heads)] for dim, heads in new_num_heads.items()},
33 | 'depth': {dim: sorted(deps) for dim, deps in new_depth.items()},
34 | }
35 |
36 |
37 | if __name__ == '__main__':
38 | import os, sys, yaml
39 |
40 | cand_file = os.path.normpath(sys.argv[1])
41 | conf_file = os.path.normpath(sys.argv[2])
42 | if os.path.exists(conf_file):
43 | print(f'Target file already exists: {conf_file}')
44 | exit()
45 |
46 | new_choices = candidate_to_choices(cand_file)
47 | #print(new_choices)
48 | cfg = dict()
49 | cfg['SEARCH_SPACE'] = {k.upper(): v for k, v in new_choices.items()}
50 | max_depth = max({dep for deps in new_choices['depth'].values() for dep in deps})
51 | max_ratio = max(max(ratio_list) for ratios in new_choices['mlp_ratio'].values() for ratio_list in ratios)
52 | max_heads = max(max(heads_list) for heads in new_choices['num_heads'].values() for heads_list in heads)
53 | max_dim = max_heads * 64
54 | assert max_dim >= max(new_choices['embed_dim'])
55 | cfg['SUPERNET'] = {'DEPTH': max_depth, 'MLP_RATIO': max_ratio, 'NUM_HEADS': max_heads, 'EMBED_DIM': max_dim}
56 |
57 | yaml.safe_dump(cfg, open(conf_file, 'w'))
58 | print(f'Saved to: {conf_file}')
59 |
--------------------------------------------------------------------------------
/evolution_pre_train.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | from pathlib import Path
8 |
9 | from lib.datasets import build_dataset
10 | from lib import utils
11 | from supernet_engine import evaluate
12 | from model.supernet_transformer import Vision_TransformerSuper
13 | import argparse
14 | import os
15 | import yaml
16 | from lib.config import cfg, update_config_from_file
17 | from lib.score_maker import ScoreMaker
18 | import math
19 | from itertools import combinations
20 | import json
21 |
22 |
23 | def decode_cand_tuple(cand_tuple):
24 | depth = cand_tuple[0]  # cand layout: (depth, mlp_ratio * depth, num_heads * depth, embed_dim)
25 | return depth, list(cand_tuple[1:depth+1]), list(cand_tuple[depth + 1: 2 * depth + 1]), cand_tuple[-1]
26 |
27 |
28 | def get_max_min_model(choices):
29 | max_depth = max(choices['depth'])
30 | max_emb = max(choices['embed_dim'])
31 | max_num_head = max(choices['num_heads'])
32 | max_mlp_ratio = max(choices['mlp_ratio'])
33 | min_depth = min(choices['depth'])
34 | min_emb = min(choices['embed_dim'])
35 | min_num_head = min(choices['num_heads'])
36 | min_mlp_ratio = min(choices['mlp_ratio'])
37 | max_model = tuple([max_depth] + [max_mlp_ratio] * max_depth + [max_num_head] * max_depth + [max_emb])
38 | min_model = tuple([min_depth] + [min_mlp_ratio] * min_depth + [min_num_head] * min_depth + [min_emb])
39 | return max_model, min_model
40 |
41 |
42 | class Searcher(object):
43 |
44 | def __init__(self, args, device, model, model_without_ddp, choices, output_dir, score_maker):
45 | self.device = device
46 | self.model = model
47 | self.model_without_ddp = model_without_ddp
48 | self.args = args
49 | self.max_epochs = args.max_epochs
50 | self.select_num = args.select_num
51 | self.population_num = args.population_num
52 | self.m_prob = args.m_prob
53 | self.crossover_num = args.crossover_num
54 | self.mutation_num = args.mutation_num
55 | self.parameters_limits = args.param_limits
56 | self.min_parameters_limits = args.min_param_limits
57 | self.output_dir = output_dir
58 | self.s_prob = args.s_prob
59 | self.memory = []
60 | self.vis_dict = {}
61 | self.keep_top_k = {}
62 | self.epoch = 0
63 | self.checkpoint_path = args.resume
64 | self.candidates = []
65 | self.top_accuracies = []
66 | self.cand_params = []
67 | self.choices = choices
68 | self.choices['num_heads'].sort()
69 | self.choices['mlp_ratio'].sort()
70 |
71 | self.score_maker = score_maker
72 | self.eval_cnt = 0
73 | self.update_num = 0
74 | self.un_update_cnt = 0
75 |
76 | self.all_cands = {}
77 | min_param = self.min_parameters_limits
78 | max_param = min_param + self.args.param_interval
79 | while max_param < self.parameters_limits + 1e-6:
80 | params = (max_param + min_param) / 2
81 | self.all_cands[self.param_to_index(params)] = []
82 | min_param = max_param
83 | max_param = min_param + self.args.param_interval
84 |
85 | self.cur_min_param = args.min_param_limits
86 | self.cur_max_param = args.param_limits
87 | self.interval_cands = {}
88 | self.max_model, self.min_model = get_max_min_model(choices)
89 | self.search_mode = args.search_mode
90 | self.head_mlp_scores = {}
91 |
92 | def get_params_range(self):
93 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(self.max_model)
94 | sampled_config = {}
95 | sampled_config['layer_num'] = depth
96 | sampled_config['mlp_ratio'] = mlp_ratio
97 | sampled_config['num_heads'] = num_heads
98 | sampled_config['embed_dim'] = [embed_dim] * depth
99 |
100 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config)
101 | max_params = n_parameters / 10. ** 6
102 |
103 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(self.min_model)
104 | sampled_config = {}
105 | sampled_config['layer_num'] = depth
106 | sampled_config['mlp_ratio'] = mlp_ratio
107 | sampled_config['num_heads'] = num_heads
108 | sampled_config['embed_dim'] = [embed_dim] * depth
109 |
110 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config)
111 | min_params = n_parameters / 10. ** 6
112 | return min_params, max_params
113 |
114 | def select_cands(self, *, key, reverse=True):
115 | for k in self.all_cands.keys():
116 | t = self.all_cands[k]
117 | t.sort(key=key, reverse=reverse)
118 | self.all_cands[k] = t[:self.args.cand_per_interval]
119 |
120 | def param_to_index(self, param):
121 | if param < self.min_parameters_limits:
122 | return -1
123 | if param >= self.parameters_limits:
124 | return -1
125 | return math.floor((param - self.min_parameters_limits) / self.args.param_interval)
126 |
127 | def index_to_param_interval(self, index):
128 | if index == -1:
129 | return (0, self.min_parameters_limits)
130 | if index == -2:
131 | return (self.parameters_limits, 2*self.parameters_limits)
132 | down = self.min_parameters_limits + index * self.args.param_interval
133 | up = down + self.args.param_interval
134 | return (down, up)
135 |
136 | def stack_random_cand(self, random_func, *, batchsize=10):
137 | while True:
138 | cands = [random_func() for _ in range(batchsize)]
139 | for cand in cands:
140 | if cand not in self.vis_dict:
141 | self.vis_dict[cand] = {}
142 | # the per-candidate info dict is filled in later by is_legal() / get_score()
143 | for cand in cands:
144 | yield cand
145 |
146 |
147 | def get_random_cand_without_reallocate(self):
148 |
149 | cand_tuple = list()
150 | dimensions = ['mlp_ratio', 'num_heads']
151 | depth = random.choice(self.choices['depth'])
152 | cand_tuple.append(depth)
153 | for dimension in dimensions:
154 | idx = list(range(len(self.choices[dimension])))
155 | random.shuffle(idx)
156 | choice_cnt = {}
157 | left_layers = depth
158 | for i in idx[:-1]:
159 | choice = self.choices[dimension][i]
160 | cnt = random.choice(range(left_layers + 1))
161 | left_layers = left_layers - cnt
162 | choice_cnt[choice] = cnt
163 | choice = self.choices[dimension][idx[-1]]
164 | choice_cnt[choice] = left_layers
165 | conf = [0] * depth
166 |
167 | for choice in self.choices[dimension][1:][::-1]:
168 | scores = np.random.rand(depth)
169 | mask = np.where(np.array(conf) > 0, -1, 1)
170 | mask_scores = scores * mask
171 | for i in mask_scores.argsort()[::-1][:choice_cnt[choice]]:
172 | conf[i] = choice
173 | for i in range(len(conf)):
174 | if conf[i] == 0:
175 | conf[i] = self.choices[dimension][0]
176 |
177 | cand_tuple.extend(conf)
178 |
179 | cand_tuple.append(random.choice(self.choices['embed_dim']))
180 | return tuple(cand_tuple)
181 |
182 | def get_random_cand(self):
183 |
184 | cand_tuple = list()
185 | dimensions = ['mlp_ratio', 'num_heads']
186 | score_names = ['mlp_scores', 'head_scores']
187 | depth = random.choice(self.choices['depth'])
188 | cand_tuple.append(depth)
189 | emb_dim = random.choice(self.choices['embed_dim'])
190 | max_dim = max(self.choices['embed_dim'])
191 | for (dimension, score_name) in zip(dimensions, score_names):
192 | idx = list(range(len(self.choices[dimension])))
193 | random.shuffle(idx)
194 | choice_cnt = {}
195 | left_layers = depth
196 | for i in idx[:-1]:
197 | choice = self.choices[dimension][i]
198 | cnt = random.choice(range(left_layers + 1))
199 | left_layers = left_layers - cnt
200 | choice_cnt[choice] = cnt
201 | choice = self.choices[dimension][idx[-1]]
202 | choice_cnt[choice] = left_layers
203 | choice_cnt_list = [choice_cnt[choice] for choice in self.choices[dimension]]
204 | method = None
205 | if dimension == 'mlp_ratio':
206 | method = self.args.block_score_method_for_mlp
207 | else:
208 | method = self.args.block_score_method_for_head
209 | cand_tuple.extend(self.reallocate(depth,
210 | emb_dim,
211 | dimension,
212 | self.head_mlp_scores[score_name],
213 | choice_cnt_list,
214 | method))
215 |
216 | cand_tuple.append(emb_dim)
217 | return tuple(cand_tuple)
218 |
219 | def get_random(self, num):
220 | print('random select ........')
221 | if self.args.search_mode == 'iteration' or self.args.reallocate:
222 | cand_iter = self.stack_random_cand(self.get_random_cand)
223 | else:
224 | cand_iter = self.stack_random_cand(self.get_random_cand_without_reallocate)
225 | while len(self.candidates) < num:
226 | cand = next(cand_iter)
227 | if not self.is_legal(cand):
228 | continue
229 | self.candidates.append(cand)
230 | print('random {}/{}'.format(len(self.candidates), num))
231 | print('random_num = {}'.format(len(self.candidates)))
232 |
233 | def is_legal(self, cand):
234 | assert isinstance(cand, tuple)
235 |
236 | if cand not in self.vis_dict:
237 | self.vis_dict[cand] = {}
238 | info = self.vis_dict[cand]
239 | if 'visited' in info:
240 | return False
241 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand)
242 | sampled_config = {}
243 | sampled_config['layer_num'] = depth
244 | sampled_config['mlp_ratio'] = mlp_ratio
245 | sampled_config['num_heads'] = num_heads
246 | sampled_config['embed_dim'] = [embed_dim]*depth
247 |
248 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config)
249 | info['params'] = n_parameters / 10.**6
250 |
251 | if info['params'] > self.cur_max_param:
252 | print('parameters limit exceed {}'.format(self.cur_max_param))
253 | return False
254 |
255 | if info['params'] < self.cur_min_param:
256 | print('under minimum parameters limit {}'.format(self.cur_min_param))
257 | return False
258 |
259 | info['visited'] = True
260 |
261 | return True
262 |
263 | def conf_to_cnt_list(self, conf, part):
264 | cnt_list = [0]*len(self.choices[part])
265 | for choice in conf:
266 | cnt_list[self.choices[part].index(choice)] += 1
267 | return cnt_list
268 |
269 | def reallocate(self, depth, embed_dim, part, scores, choice_cnt, method):
270 |
271 | if method == 'deeper_is_better':
272 | conf = []
273 | for choice, cnt in zip(self.choices[part], choice_cnt):
274 | conf = conf + ([choice] * cnt)
275 | return conf
276 |
277 | if 'max_dim' in method:
278 | embed_dim = max(self.choices['embed_dim'])
279 |
280 | conf = [0] * depth
281 | for choice, cnt in zip(self.choices[part][1:][::-1], choice_cnt[1:][::-1]):
282 | cur_scores = np.array(scores[(f"{embed_dim},{choice}")][:depth])
283 | mask = np.where(np.array(conf) > 0, -1, 1)
284 | mask_scores = cur_scores * mask
285 | for i in mask_scores.argsort()[::-1][:cnt]:
286 | conf[i] = choice
287 | for i in range(len(conf)):
288 | if conf[i] == 0:
289 | conf[i] = self.choices[part][0]
290 | return conf
291 |
292 | def get_score(self):
293 | for cand in self.candidates:
294 | info = self.vis_dict[cand]
295 | if self.args.score_method == 'params':
296 | info['score'] = info['params']
297 | else:
298 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand)
299 | sampled_config = {}
300 | sampled_config['layer_num'] = depth
301 | sampled_config['mlp_ratio'] = mlp_ratio
302 | sampled_config['num_heads'] = num_heads
303 | sampled_config['embed_dim'] = [embed_dim] * depth
304 | score = self.score_maker.get_score(self.model, self.args.score_method, config=sampled_config)
305 | info['score'] = score
306 |
307 | def update_top_k(self, candidates, *, k, key, reverse=True, get_update_num=False):
308 | assert k in self.keep_top_k
309 | print('select ......')
310 | t = self.keep_top_k[k]
311 | t += candidates
312 | t.sort(key=key, reverse=reverse)
313 | self.keep_top_k[k] = t[:k]
314 | if get_update_num:
315 | self.update_num = 0
316 | for cand in self.keep_top_k[k]:
317 | if cand in candidates:
318 | self.update_num += 1
319 | print('update {} models in top {}.'.format(self.update_num, k))
320 | if self.update_num == 0:
321 | self.un_update_cnt += 1
322 |
323 | def get_mutation(self, k, mutation_num, m_prob, s_prob):
324 | assert k in self.keep_top_k
325 | print('mutation ......')
326 | res = []
327 | # cap attempts at 10x mutation_num to avoid an endless loop
328 | max_iters = mutation_num * 10
329 |
330 | def random_func():
331 | cand = list(random.choice(self.keep_top_k[k]))
332 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand)
333 | random_s = random.random()
334 |
335 | # depth
336 | if random_s < s_prob:
337 | new_depth = random.choice(self.choices['depth'])
338 |
339 | if new_depth > depth:
340 | mlp_ratio = mlp_ratio + [random.choice(self.choices['mlp_ratio']) for _ in range(new_depth - depth)]
341 | num_heads = num_heads + [random.choice(self.choices['num_heads']) for _ in range(new_depth - depth)]
342 | else:
343 | mlp_ratio = mlp_ratio[:new_depth]
344 | num_heads = num_heads[:new_depth]
345 |
346 | depth = new_depth
347 | # mlp_ratio
348 | for i in range(depth):
349 | random_s = random.random()
350 | if random_s < m_prob:
351 | mlp_ratio[i] = random.choice(self.choices['mlp_ratio'])
352 |
353 | # num_heads
354 |
355 | for i in range(depth):
356 | random_s = random.random()
357 | if random_s < m_prob:
358 | num_heads[i] = random.choice(self.choices['num_heads'])
359 |
360 | # embed_dim
361 | random_s = random.random()
362 | if random_s < s_prob:
363 | embed_dim = random.choice(self.choices['embed_dim'])
364 |
365 | mlp_cnt = self.conf_to_cnt_list(mlp_ratio, 'mlp_ratio')
366 | head_cnt = self.conf_to_cnt_list(num_heads, 'num_heads')
367 | mlp_ratio = self.reallocate(depth,
368 | embed_dim,
369 | 'mlp_ratio',
370 | self.head_mlp_scores['mlp_scores'],
371 | mlp_cnt,
372 | self.args.block_score_method_for_mlp)
373 | num_heads = self.reallocate(depth,
374 | embed_dim,
375 | 'num_heads',
376 | self.head_mlp_scores['head_scores'],
377 | head_cnt,
378 | self.args.block_score_method_for_head)
379 |
380 | result_cand = [depth] + mlp_ratio + num_heads + [embed_dim]
381 |
382 | return tuple(result_cand)
383 |
384 | cand_iter = self.stack_random_cand(random_func)
385 | while len(res) < mutation_num and max_iters > 0:
386 | max_iters -= 1
387 | cand = next(cand_iter)
388 | if not self.is_legal(cand):
389 | continue
390 | res.append(cand)
391 | print('mutation {}/{}'.format(len(res), mutation_num))
392 |
393 | print('mutation_num = {}'.format(len(res)))
394 | return res
395 |
396 | def get_crossover(self, k, crossover_num):
397 | assert k in self.keep_top_k
398 | print('crossover ......')
399 | res = []
400 | # cap attempts at 10x crossover_num to avoid an endless loop
401 | max_iters = 10 * crossover_num
402 |
403 | def random_func():
404 |
405 | p1 = random.choice(self.keep_top_k[k])
406 | p2 = random.choice(self.keep_top_k[k])
407 | max_iters_tmp = 50
408 | while len(p1) != len(p2) and max_iters_tmp > 0:
409 | max_iters_tmp -= 1
410 | p1 = random.choice(self.keep_top_k[k])
411 | p2 = random.choice(self.keep_top_k[k])
412 | cand = tuple(random.choice([i, j]) for i, j in zip(p1, p2))
413 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand)
414 | mlp_cnt = self.conf_to_cnt_list(mlp_ratio, 'mlp_ratio')
415 | head_cnt = self.conf_to_cnt_list(num_heads, 'num_heads')
416 | mlp_ratio = self.reallocate(depth,
417 | embed_dim,
418 | 'mlp_ratio',
419 | self.head_mlp_scores['mlp_scores'],
420 | mlp_cnt,
421 | self.args.block_score_method_for_mlp)
422 | num_heads = self.reallocate(depth,
423 | embed_dim,
424 | 'num_heads',
425 | self.head_mlp_scores['head_scores'],
426 | head_cnt,
427 | self.args.block_score_method_for_head)
428 | result_cand = [depth] + mlp_ratio + num_heads + [embed_dim]
429 | return tuple(result_cand)
430 |
431 | cand_iter = self.stack_random_cand(random_func)
432 | while len(res) < crossover_num and max_iters > 0:
433 | max_iters -= 1
434 | cand = next(cand_iter)
435 | if not self.is_legal(cand):
436 | continue
437 | res.append(cand)
438 | print('crossover {}/{}'.format(len(res), crossover_num))
439 |
440 | print('crossover_num = {}'.format(len(res)))
441 | return res
442 |
443 | def search(self, out_file_name=None):
444 |
445 | print('searching...')
446 | if self.args.block_score_method_for_mlp != 'deeper_is_better' or self.args.block_score_method_for_head != 'deeper_is_better':
447 | self.head_mlp_scores = self.score_maker.get_block_scores(self.model, self.args, self.choices)
448 |
449 | # random search
450 | if self.args.search_mode == 'random':
451 | self.cur_min_param = self.min_parameters_limits
452 | self.cur_max_param = self.cur_min_param + self.args.param_interval
453 |
454 | while self.cur_max_param < self.parameters_limits + 1e-6:
455 | self.candidates = []
456 | self.keep_top_k = {100: []}
457 | self.get_random(self.population_num)
458 | self.get_score()
459 | self.update_top_k(
460 | self.candidates, k=100, key=lambda x: self.vis_dict[x]['score'])
461 | for i, cand in enumerate(self.keep_top_k[100]):
462 | print('No.{} {} score = {}, params = {}'.format(
463 | i + 1, cand, self.vis_dict[cand]['score'], self.vis_dict[cand]['params']))
464 | self.interval_cands[(self.cur_min_param, self.cur_max_param)] = self.keep_top_k[100][:self.args.cand_per_interval]
465 | self.cur_min_param = self.cur_max_param
466 | self.cur_max_param = self.cur_min_param + self.args.param_interval
467 | # evolution search
468 | elif self.args.search_mode == 'evolution':
469 | self.cur_min_param = self.min_parameters_limits
470 | self.cur_max_param = self.cur_min_param + self.args.param_interval
471 |
472 | while self.cur_max_param < self.parameters_limits + 1e-6:
473 | self.update_num = 0
474 | self.un_update_cnt = 0
475 | self.epoch = 0
476 | self.candidates = []
477 | self.keep_top_k = {self.select_num: [], 100: []}
478 | self.get_random(self.population_num)
479 | while self.epoch < self.max_epochs:
480 | print('epoch = {} for param {} to param {}'.format(self.epoch, self.cur_min_param, self.cur_max_param))
481 |
482 | if self.un_update_cnt == 2:
483 | self.epoch += 1
484 | continue
485 |
486 | self.get_score()
487 | self.update_top_k(
488 | self.candidates, k=self.select_num, key=lambda x: self.vis_dict[x]['score'], get_update_num=True)
489 | self.update_top_k(
490 | self.candidates, k=100, key=lambda x: self.vis_dict[x]['score'])
491 |
492 | print('epoch = {} for param {} to param {} : top {} result'.format(
493 | self.epoch, self.cur_min_param, self.cur_max_param, len(self.keep_top_k[100])))
494 | for i, cand in enumerate(self.keep_top_k[100]):
495 | print('No.{} {} score = {}, params = {}'.format(
496 | i + 1, cand, self.vis_dict[cand]['score'], self.vis_dict[cand]['params']))
497 |
498 | self.epoch += 1
499 | if self.epoch >= self.max_epochs:
500 | break
501 |
502 | # check
503 | mutation = self.get_mutation(
504 | self.select_num, self.mutation_num, self.m_prob, self.s_prob)
505 | crossover = self.get_crossover(self.select_num, self.crossover_num)
506 |
507 | self.candidates = mutation + crossover
508 |
509 | self.get_random(self.population_num)
510 |
511 | self.interval_cands[(self.cur_min_param, self.cur_max_param)] = self.keep_top_k[100][:self.args.cand_per_interval]
512 | self.cur_min_param = self.cur_max_param
513 | self.cur_max_param = self.cur_min_param + self.args.param_interval
514 | # force search
515 | else:
516 | max_dim = max(self.choices['embed_dim'])
517 | iter_cnt = 0
518 | for embed_dim in self.choices['embed_dim']:
519 | for depth in self.choices['depth']:
520 | depth_ids = list(range(depth+1))
521 | num_head_choice = len(self.choices['num_heads'])
522 | num_mlp_choice = len(self.choices['mlp_ratio'])
523 | mlp_confs = []
524 | head_confs = []
525 |
526 | for mlp_dist in combinations(depth_ids, num_mlp_choice - 1):
527 | mlp_dist = [0] + list(mlp_dist) + [depth]
528 | mlp_cnt = [mlp_dist[i+1] - mlp_dist[i] for i in range(len(mlp_dist)-1)]
529 | mlp_confs.append(self.reallocate(depth,
530 | embed_dim,
531 | 'mlp_ratio',
532 | self.head_mlp_scores['mlp_scores'],
533 | mlp_cnt,
534 | self.args.block_score_method_for_mlp))
535 |
536 | for head_dist in combinations(depth_ids, num_head_choice - 1):
537 | head_dist = [0] + list(head_dist) + [depth]
538 | head_cnt = [head_dist[i+1] - head_dist[i] for i in range(len(head_dist)-1)]
539 | head_confs.append(self.reallocate(depth,
540 | embed_dim,
541 | 'num_heads',
542 | self.head_mlp_scores['head_scores'],
543 | head_cnt,
544 | self.args.block_score_method_for_head))
545 |
546 | for mlp_conf in mlp_confs:
547 | iter_cnt += 1
548 | for head_conf in head_confs:
549 | cand = tuple([depth] + mlp_conf + head_conf + [embed_dim])
550 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand)
551 | sampled_config = {}
552 | sampled_config['layer_num'] = depth
553 | sampled_config['mlp_ratio'] = mlp_ratio
554 | sampled_config['num_heads'] = num_heads
555 | sampled_config['embed_dim'] = [embed_dim] * depth
556 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config)
557 | params = n_parameters / 10. ** 6
558 | index = self.param_to_index(params)
559 |
560 | if self.args.score_method == 'params':
561 | score = params
562 | else:
563 | score = self.score_maker.get_score(self.model, self.args.score_method, config=sampled_config)
564 |
565 | info = {'cand': cand, 'score': score, 'params': params}
566 | self.vis_dict[cand] = info
567 | if index in self.all_cands.keys():
568 | self.all_cands[index].append(info)
569 |
570 | self.select_cands(key=lambda x: x['score'])
571 |
572 | for index in self.all_cands.keys():
573 | k = self.index_to_param_interval(index)
574 | self.interval_cands[k] = [item['cand'] for item in self.all_cands[index]]
575 |
576 | if out_file_name is None:
577 | out_file_name = f'out/interval_cands_{self.args.super_model_size}_{self.args.score_method}_{self.args.block_score_method_for_mlp}_for_mlp_{self.args.block_score_method_for_head}_for_head'
578 | out_file_name += f'_i{self.args.param_interval}_top_{self.args.cand_per_interval}.pt'
579 | torch.save(self.interval_cands, out_file_name)
580 | else:
581 | json_dict = {}
582 | for interval in self.interval_cands.keys():
583 | cand_list = []
584 | for cand in self.interval_cands[interval]:
585 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand)
586 | info = {
587 | 'layer_num': depth,
588 | 'mlp_ratio': mlp_ratio,
589 | 'num_heads': num_heads,
590 | 'embed_dim': [embed_dim]*depth,
591 | 'num_params': float(self.vis_dict[cand]['params']),
592 | 'score': float(self.vis_dict[cand]['score'])
593 | }
594 | cand_list.append(info)
595 | if len(cand_list) > 0:
596 | json_dict[str(interval[1])] = cand_list
597 | print("selected candidates:")
598 | print(json_dict)
599 | with open(out_file_name, "w") as fp:
600 | json.dump(json_dict, fp, indent=2)
601 |
602 |
603 |
604 |
605 | return self.interval_cands
606 |
607 |
--------------------------------------------------------------------------------
/experiments/supernet/base.yaml:
--------------------------------------------------------------------------------
1 | SEARCH_SPACE:
2 | DEPTH:
3 | 528:
4 | - 14
5 | 576:
6 | - 14
7 | 624:
8 | - 14
9 | EMBED_DIM:
10 | - 528
11 | - 576
12 | - 624
13 | MLP_RATIO:
14 | 528:
15 | - - 3.0
16 | - - 3.0
17 | - - 3.0
18 | - - 3.0
19 | - - 3.0
20 | - - 3.0
21 | - - 3.0
22 | - - 3.0
23 | - - 3.0
24 | - - 3.0
25 | - - 3.0
26 | - - 3.0
27 | - - 3.5
28 | - - 3.0
29 | 576:
30 | - - 3.5
31 | - - 4.0
32 | - - 3.5
33 | - - 3.5
34 | - - 3.5
35 | - - 4.0
36 | - - 3.5
37 | - - 3.5
38 | - - 3.5
39 | - - 3.5
40 | - - 4.0
41 | - - 3.0
42 | - - 3.5
43 | - - 4.0
44 | 624:
45 | - - 4.0
46 | - - 4.0
47 | - - 3.5
48 | - - 4.0
49 | - - 4.0
50 | - - 4.0
51 | - - 4.0
52 | - - 4.0
53 | - - 4.0
54 | - - 4.0
55 | - - 4.0
56 | - - 4.0
57 | - - 4.0
58 | - - 4.0
59 | NUM_HEADS:
60 | 528:
61 | - - 9
62 | - - 9
63 | - - 9
64 | - - 9
65 | - - 9
66 | - - 9
67 | - - 9
68 | - - 9
69 | - - 9
70 | - - 9
71 | - - 9
72 | - - 9
73 | - - 9
74 | - - 9
75 | 576:
76 | - - 9
77 | - - 9
78 | - - 10
79 | - - 9
80 | - - 9
81 | - - 10
82 | - - 9
83 | - - 9
84 | - - 9
85 | - - 9
86 | - - 9
87 | - - 9
88 | - - 10
89 | - - 9
90 | 624:
91 | - - 9
92 | - - 10
93 | - - 10
94 | - - 9
95 | - - 9
96 | - - 10
97 | - - 10
98 | - - 9
99 | - - 10
100 | - - 9
101 | - - 9
102 | - - 9
103 | - - 10
104 | - - 10
105 | SUPERNET:
106 | DEPTH: 14
107 | EMBED_DIM: 640
108 | MLP_RATIO: 4.0
109 | NUM_HEADS: 10
110 |
--------------------------------------------------------------------------------
/experiments/supernet/small.yaml:
--------------------------------------------------------------------------------
1 | SEARCH_SPACE:
2 | DEPTH:
3 | 320:
4 | - 13
5 | 384:
6 | - 13
7 | 448:
8 | - 13
9 | - 14
10 | EMBED_DIM:
11 | - 320
12 | - 384
13 | - 448
14 | MLP_RATIO:
15 | 320:
16 | - - 4.0
17 | - - 4.0
18 | - - 3.5
19 | - - 4.0
20 | - - 3.5
21 | - - 4.0
22 | - - 4.0
23 | - - 3.5
24 | - - 4.0
25 | - - 4.0
26 | - - 3.5
27 | - - 4.0
28 | - - 4.0
29 | 384:
30 | - - 4.0
31 | - - 4.0
32 | - - 3.5
33 | - - 4.0
34 | - - 3.5
35 | - - 4.0
36 | - - 3.5
37 | - - 3.5
38 | - - 4.0
39 | - - 4.0
40 | - - 3.0
41 | - - 4.0
42 | - - 3.5
43 | 448:
44 | - - 4.0
45 | - - 4.0
46 | - - 3.0
47 | - 3.5
48 | - - 4.0
49 | - - 3.0
50 | - 3.5
51 | - - 3.5
52 | - 4.0
53 | - - 3.0
54 | - 4.0
55 | - - 3.0
56 | - 3.5
57 | - - 4.0
58 | - - 3.0
59 | - 4.0
60 | - - 3.0
61 | - 3.5
62 | - - 4.0
63 | - - 3.0
64 | - 4.0
65 | - - 3.5
66 | NUM_HEADS:
67 | 320:
68 | - - 7
69 | - - 7
70 | - - 7
71 | - - 7
72 | - - 7
73 | - - 7
74 | - - 7
75 | - - 7
76 | - - 6
77 | - - 7
78 | - - 5
79 | - - 6
80 | - - 5
81 | 384:
82 | - - 7
83 | - - 7
84 | - - 7
85 | - - 7
86 | - - 5
87 | - - 7
88 | - - 7
89 | - - 5
90 | - - 6
91 | - - 5
92 | - - 5
93 | - - 6
94 | - - 5
95 | 448:
96 | - - 7
97 | - - 7
98 | - - 7
99 | - - 7
100 | - - 5
101 | - 7
102 | - - 7
103 | - - 7
104 | - - 5
105 | - 7
106 | - - 5
107 | - 6
108 | - - 5
109 | - 7
110 | - - 5
111 | - - 6
112 | - - 5
113 | - - 7
114 | SUPERNET:
115 | DEPTH: 14
116 | EMBED_DIM: 448
117 | MLP_RATIO: 4.0
118 | NUM_HEADS: 7
119 |
--------------------------------------------------------------------------------
/experiments/supernet/supernet-B.yaml:
--------------------------------------------------------------------------------
1 | SUPERNET:
2 | MLP_RATIO: 4.0
3 | NUM_HEADS: 10
4 | EMBED_DIM: 640
5 | DEPTH: 16
6 | SEARCH_SPACE:
7 | MLP_RATIO:
8 | - 3.0
9 | - 3.5
10 | - 4.0
11 | NUM_HEADS:
12 | - 9
13 | - 10
14 | DEPTH:
15 | - 14
16 | - 15
17 | - 16
18 | EMBED_DIM:
19 | - 528
20 | - 576
21 | - 624
22 |
--------------------------------------------------------------------------------
/experiments/supernet/supernet-S.yaml:
--------------------------------------------------------------------------------
1 | SUPERNET:
2 | MLP_RATIO: 4.0
3 | NUM_HEADS: 7
4 | EMBED_DIM: 448
5 | DEPTH: 14
6 | SEARCH_SPACE:
7 | MLP_RATIO:
8 | - 3.0
9 | - 3.5
10 | - 4.0
11 | NUM_HEADS:
12 | - 5
13 | - 6
14 | - 7
15 | DEPTH:
16 | - 13
17 | - 14
18 | EMBED_DIM:
19 | - 320
20 | - 384
21 | - 448
22 |
--------------------------------------------------------------------------------
/experiments/supernet/supernet-T.yaml:
--------------------------------------------------------------------------------
1 | SUPERNET:
2 | MLP_RATIO: 4.0
3 | NUM_HEADS: 4
4 | EMBED_DIM: 256
5 | DEPTH: 14
6 | SEARCH_SPACE:
7 | MLP_RATIO:
8 | - 3.5
9 | - 4
10 | NUM_HEADS:
11 | - 3
12 | - 4
13 | DEPTH:
14 | - 12
15 | - 13
16 | - 14
17 | EMBED_DIM:
18 | - 192
19 | - 216
20 | - 240
21 |
--------------------------------------------------------------------------------
/experiments/supernet/tiny.yaml:
--------------------------------------------------------------------------------
1 | SEARCH_SPACE:
2 | DEPTH:
3 | 192:
4 | - 12
5 | 216:
6 | - 12
7 | 240:
8 | - 12
9 | - 14
10 | EMBED_DIM:
11 | - 192
12 | - 216
13 | - 240
14 | MLP_RATIO:
15 | 192:
16 | - - 4
17 | - - 4
18 | - - 4
19 | - - 4
20 | - - 4
21 | - - 3.5
22 | - - 3.5
23 | - - 4
24 | - - 4
25 | - - 3.5
26 | - - 4
27 | - - 4
28 | 216:
29 | - - 4
30 | - - 3.5
31 | - - 4
32 | - - 3.5
33 | - - 3.5
34 | - - 3.5
35 | - - 3.5
36 | - - 3.5
37 | - - 3.5
38 | - - 3.5
39 | - - 3.5
40 | - - 3.5
41 | 240:
42 | - - 4
43 | - - 3.5
44 | - 4
45 | - - 4
46 | - - 4
47 | - - 3.5
48 | - 4
49 | - - 3.5
50 | - 4
51 | - - 3.5
52 | - 4
53 | - - 4
54 | - - 4
55 | - - 3.5
56 | - 4
57 | - - 4
58 | - - 3.5
59 | - 4
60 | - - 4
61 | - - 4
62 | NUM_HEADS:
63 | 192:
64 | - - 4
65 | - - 4
66 | - - 4
67 | - - 4
68 | - - 3
69 | - - 3
70 | - - 3
71 | - - 3
72 | - - 3
73 | - - 4
74 | - - 3
75 | - - 4
76 | 216:
77 | - - 4
78 | - - 4
79 | - - 4
80 | - - 4
81 | - - 4
82 | - - 3
83 | - - 3
84 | - - 3
85 | - - 4
86 | - - 4
87 | - - 3
88 | - - 4
89 | 240:
90 | - - 3
91 | - 4
92 | - - 3
93 | - 4
94 | - - 3
95 | - 4
96 | - - 3
97 | - 4
98 | - - 3
99 | - 4
100 | - - 3
101 | - 4
102 | - - 3
103 | - 4
104 | - - 3
105 | - 4
106 | - - 3
107 | - 4
108 | - - 3
109 | - 4
110 | - - 3
111 | - 4
112 | - - 3
113 | - 4
114 | - - 3
115 | - 4
116 | - - 4
117 | SUPERNET:
118 | DEPTH: 14
119 | EMBED_DIM: 256
120 | MLP_RATIO: 4
121 | NUM_HEADS: 4
122 |
--------------------------------------------------------------------------------
/interval_cands/base.json:
--------------------------------------------------------------------------------
1 | {"42.0": [{"layer_num": 14, "mlp_ratio": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.5, 3.0], "num_heads": [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9], "embed_dim": [528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528], "num_params": 41.966944, "score": 5461899.5}], "54.0": [{"layer_num": 14, "mlp_ratio": [3.5, 4.0, 3.5, 3.5, 3.5, 4.0, 3.5, 3.5, 3.5, 3.5, 4.0, 3.0, 3.5, 4.0], "num_heads": [9, 9, 10, 9, 9, 10, 9, 9, 9, 9, 9, 9, 10, 9], "embed_dim": [576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576], "num_params": 53.876104, "score": 6887513.0}], "66.0": [{"layer_num": 14, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0], "num_heads": [9, 10, 10, 9, 9, 10, 10, 9, 10, 9, 9, 9, 10, 10], "embed_dim": [624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624], "num_params": 65.916448, "score": 8251315.0}]}
--------------------------------------------------------------------------------
/interval_cands/small.json:
--------------------------------------------------------------------------------
1 | {"18.0": [{"layer_num": 13, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0], "num_heads": [7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 5, 6, 5], "embed_dim": [320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320], "num_params": 17.9914, "score": 2764111.5}], "23.0": [{"layer_num": 13, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 4.0, 4.0, 3.0, 4.0, 3.5], "num_heads": [7, 7, 7, 7, 5, 7, 7, 5, 6, 5, 5, 6, 5], "embed_dim": [384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384], "num_params": 22.989928, "score": 3528041.5}], "28.0": [{"layer_num": 13, "mlp_ratio": [4.0, 4.0, 3.0, 4.0, 3.0, 3.5, 3.0, 3.0, 4.0, 3.0, 3.0, 4.0, 3.0], "num_heads": [7, 7, 7, 7, 5, 7, 7, 5, 6, 5, 5, 6, 5], "embed_dim": [448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448], "num_params": 27.976008, "score": 4264070.0}], "33.0": [{"layer_num": 14, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0, 3.5], "num_heads": [7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 5, 6, 5, 7], "embed_dim": [448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448], "num_params": 32.98164, "score": 4739829.0}]}
--------------------------------------------------------------------------------
/interval_cands/tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "6.0": [
3 | {
4 | "layer_num": 12,
5 | "mlp_ratio": [
6 | 4,
7 | 4,
8 | 4,
9 | 4,
10 | 4,
11 | 3.5,
12 | 3.5,
13 | 4,
14 | 4,
15 | 3.5,
16 | 4,
17 | 4
18 | ],
19 | "num_heads": [
20 | 4,
21 | 4,
22 | 4,
23 | 4,
24 | 3,
25 | 3,
26 | 3,
27 | 3,
28 | 3,
29 | 4,
30 | 3,
31 | 4
32 | ],
33 | "embed_dim": [
34 | 192,
35 | 192,
36 | 192,
37 | 192,
38 | 192,
39 | 192,
40 | 192,
41 | 192,
42 | 192,
43 | 192,
44 | 192,
45 | 192
46 | ],
47 | "num_params": 5.99476,
48 | "score": 1010076.75
49 | }
50 | ],
51 | "7.0": [
52 | {
53 | "layer_num": 12,
54 | "mlp_ratio": [
55 | 4,
56 | 3.5,
57 | 4,
58 | 3.5,
59 | 3.5,
60 | 3.5,
61 | 3.5,
62 | 3.5,
63 | 3.5,
64 | 3.5,
65 | 3.5,
66 | 3.5
67 | ],
68 | "num_heads": [
69 | 4,
70 | 4,
71 | 4,
72 | 4,
73 | 4,
74 | 3,
75 | 3,
76 | 3,
77 | 4,
78 | 4,
79 | 3,
80 | 4
81 | ],
82 | "embed_dim": [
83 | 216,
84 | 216,
85 | 216,
86 | 216,
87 | 216,
88 | 216,
89 | 216,
90 | 216,
91 | 216,
92 | 216,
93 | 216,
94 | 216
95 | ],
96 | "num_params": 6.997192,
97 | "score": 1156532.875
98 | }
99 | ],
100 | "8.0": [
101 | {
102 | "layer_num": 12,
103 | "mlp_ratio": [
104 | 4,
105 | 3.5,
106 | 4,
107 | 4,
108 | 3.5,
109 | 3.5,
110 | 3.5,
111 | 4,
112 | 4,
113 | 3.5,
114 | 4,
115 | 3.5
116 | ],
117 | "num_heads": [
118 | 3,
119 | 3,
120 | 3,
121 | 3,
122 | 3,
123 | 3,
124 | 3,
125 | 3,
126 | 3,
127 | 3,
128 | 3,
129 | 3
130 | ],
131 | "embed_dim": [
132 | 240,
133 | 240,
134 | 240,
135 | 240,
136 | 240,
137 | 240,
138 | 240,
139 | 240,
140 | 240,
141 | 240,
142 | 240,
143 | 240
144 | ],
145 | "num_params": 7.996552,
146 | "score": 1308445.75
147 | }
148 | ],
149 | "9.0": [
150 | {
151 | "layer_num": 12,
152 | "mlp_ratio": [
153 | 4,
154 | 4,
155 | 4,
156 | 4,
157 | 4,
158 | 3.5,
159 | 4,
160 | 4,
161 | 4,
162 | 3.5,
163 | 4,
164 | 4
165 | ],
166 | "num_heads": [
167 | 4,
168 | 4,
169 | 4,
170 | 4,
171 | 4,
172 | 4,
173 | 4,
174 | 4,
175 | 4,
176 | 4,
177 | 4,
178 | 4
179 | ],
180 | "embed_dim": [
181 | 240,
182 | 240,
183 | 240,
184 | 240,
185 | 240,
186 | 240,
187 | 240,
188 | 240,
189 | 240,
190 | 240,
191 | 240,
192 | 240
193 | ],
194 | "num_params": 8.967016,
195 | "score": 1426972.25
196 | }
197 | ],
198 | "10.0": [
199 | {
200 | "layer_num": 14,
201 | "mlp_ratio": [
202 | 4,
203 | 4,
204 | 4,
205 | 4,
206 | 4,
207 | 3.5,
208 | 3.5,
209 | 4,
210 | 4,
211 | 3.5,
212 | 4,
213 | 3.5,
214 | 4,
215 | 4
216 | ],
217 | "num_heads": [
218 | 4,
219 | 4,
220 | 4,
221 | 4,
222 | 4,
223 | 3,
224 | 3,
225 | 3,
226 | 4,
227 | 4,
228 | 3,
229 | 4,
230 | 3,
231 | 4
232 | ],
233 | "embed_dim": [
234 | 240,
235 | 240,
236 | 240,
237 | 240,
238 | 240,
239 | 240,
240 | 240,
241 | 240,
242 | 240,
243 | 240,
244 | 240,
245 | 240,
246 | 240,
247 | 240
248 | ],
249 | "num_params": 9.978232,
250 | "score": 1525250.25
251 | }
252 | ],
253 | "11.0": [
254 | {
255 | "layer_num": 14,
256 | "mlp_ratio": [
257 | 4,
258 | 4,
259 | 4,
260 | 4,
261 | 4,
262 | 4,
263 | 4,
264 | 4,
265 | 4,
266 | 4,
267 | 4,
268 | 4,
269 | 4,
270 | 4
271 | ],
272 | "num_heads": [
273 | 4,
274 | 4,
275 | 4,
276 | 4,
277 | 4,
278 | 4,
279 | 4,
280 | 4,
281 | 4,
282 | 4,
283 | 4,
284 | 4,
285 | 4,
286 | 4
287 | ],
288 | "embed_dim": [
289 | 240,
290 | 240,
291 | 240,
292 | 240,
293 | 240,
294 | 240,
295 | 240,
296 | 240,
297 | 240,
298 | 240,
299 | 240,
300 | 240,
301 | 240,
302 | 240
303 | ],
304 | "num_params": 10.517272,
305 | "score": 1581370.75
306 | }
307 | ]
308 | }
--------------------------------------------------------------------------------
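
The three interval_cands files above share one schema: each top-level key is a parameter budget in millions, and each value is a list of searched candidates carrying the sampled depth, per-layer mlp_ratio and num_heads, per-layer embed_dim, the exact parameter count, and the zero-shot proxy score. A minimal sketch for consuming them (the file path is illustrative):

import json

# Keys are parameter budgets ("6.0" means ~6M params); values are candidate
# architectures that can be ranked by their zero-shot proxy score.
with open('./interval_cands/tiny.json') as f:
    cands = json.load(f)

for budget, entries in sorted(cands.items(), key=lambda kv: float(kv[0])):
    best = max(entries, key=lambda c: c['score'])
    print(f"{budget}M budget: depth={best['layer_num']}, "
          f"embed_dim={best['embed_dim'][0]}, params={best['num_params']:.2f}M")
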
/lib/config.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict
2 | import yaml
3 |
4 |
5 | class edict(EasyDict):
6 | def __setattr__(self, name, value):
7 | if isinstance(value, (list, tuple)):
8 | value = [self.__class__(x)
9 | if isinstance(x, dict) else x for x in value]
10 | elif isinstance(value, dict) and not isinstance(value, self.__class__):
11 |             if not value or not isinstance(next(iter(value)), int):  # guard empty dicts; int-keyed choice tables stay plain dicts
12 | value = self.__class__(value)
13 | dict.__setattr__(self, name, value)
14 | dict.__setitem__(self, name, value)
15 |
16 |
17 | cfg = edict()
18 |
19 |
20 | def _edict2dict(dest_dict, src_edict):
21 | if isinstance(dest_dict, dict) and isinstance(src_edict, dict):
22 | for k, v in src_edict.items():
23 | if not isinstance(v, edict):
24 | dest_dict[k] = v
25 | else:
26 | dest_dict[k] = {}
27 | _edict2dict(dest_dict[k], v)
28 | else:
29 | return
30 |
31 | def gen_config(config_file):
32 | cfg_dict = {}
33 | _edict2dict(cfg_dict, cfg)
34 | with open(config_file, 'w') as f:
35 | yaml.dump(cfg_dict, f, default_flow_style=False)
36 |
37 |
38 | def _update_config(base_cfg, exp_cfg):
39 | if isinstance(base_cfg, edict) and isinstance(exp_cfg, edict):
40 | for k, v in exp_cfg.items():
41 | base_cfg[k] = v
42 | else:
43 | return
44 |
45 |
46 | def update_config_from_file(filename):
47 | exp_config = None
48 | with open(filename) as f:
49 | exp_config = edict(yaml.safe_load(f))
50 | _update_config(cfg, exp_config)
51 |
52 |
53 |
--------------------------------------------------------------------------------
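
A short usage sketch for the helpers above, assuming the supernet YAMLs define top-level sections such as SUPERNET (as in the config shown earlier); note that _update_config overwrites whole top-level sections of the global cfg rather than deep-merging them:

from lib.config import cfg, update_config_from_file, gen_config

# Merge the experiment YAML into the module-level cfg edict.
update_config_from_file('./experiments/supernet/supernet-T.yaml')
print(cfg.SUPERNET.EMBED_DIM, cfg.SUPERNET.DEPTH)

# Round-trip: serialize the merged cfg back to plain YAML.
gen_config('./cfg_dump.yaml')
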
/lib/cuda.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from timm.utils.clip_grad import dispatch_clip_grad
3 |
4 |
5 | class NativeScaler:
6 | state_dict_key = "amp_scaler"
7 |
8 | def __init__(self):
9 | self._scaler = torch.cuda.amp.GradScaler()
10 |
11 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
12 |         # the scaled backward pass is left to the caller: run self._scaler.scale(loss).backward(create_graph=create_graph) before calling this
13 | if clip_grad is not None:
14 | assert parameters is not None
15 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
16 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
17 | self._scaler.step(optimizer)
18 | self._scaler.update()
19 |
20 | def state_dict(self):
21 | return self._scaler.state_dict()
22 |
23 | def load_state_dict(self, state_dict):
24 | self._scaler.load_state_dict(state_dict)
25 |
--------------------------------------------------------------------------------
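
Because the scale-and-backward call is commented out inside __call__, this NativeScaler variant expects the caller to run the scaled backward pass itself. A minimal sketch of that calling convention (assumes a CUDA device; the toy model and loss are placeholders):

import torch
from torch import nn
from lib.cuda import NativeScaler

model = nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.randn(4, 8, device='cuda')).sum()

# The caller scales the loss and backpropagates first ...
loss_scaler._scaler.scale(loss).backward()
# ... then the scaler unscales, clips, steps and updates.
loss_scaler(loss, optimizer, clip_grad=5.0, parameters=model.parameters())
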
/lib/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import scipy.io as sio
8 |
9 | from torchvision import datasets, transforms
10 | from torchvision.datasets.folder import ImageFolder, default_loader
11 |
12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13 | from timm.data import create_transform
14 |
15 | class Flowers(ImageFolder):
16 | def __init__(self, root, train=True, transform=None, **kwargs):
17 | self.dataset_root = root
18 | self.loader = default_loader
19 | self.target_transform = None
20 | self.transform = transform
21 | label_path = os.path.join(root, 'imagelabels.mat')
22 | split_path = os.path.join(root, 'setid.mat')
23 |
24 | print('Dataset Flowers is trained with resolution 224!')
25 |
26 | # labels
27 | labels = sio.loadmat(label_path)['labels'][0]
28 | self.img_to_label = dict()
29 | for i in range(len(labels)):
30 | self.img_to_label[i] = labels[i]
31 |
32 | splits = sio.loadmat(split_path)
33 | self.trnid, self.valid, self.tstid = sorted(splits['trnid'][0].tolist()), \
34 | sorted(splits['valid'][0].tolist()), \
35 | sorted(splits['tstid'][0].tolist())
36 | if train:
37 | self.imgs = self.trnid + self.valid
38 | else:
39 | self.imgs = self.tstid
40 |
41 | self.samples = []
42 | for item in self.imgs:
43 | self.samples.append((os.path.join(root, 'jpg', "image_{:05d}.jpg".format(item)), self.img_to_label[item-1]-1))
44 |
45 | class Cars196(ImageFolder, datasets.CIFAR10):
46 | base_folder_devkit = 'devkit'
47 | base_folder_trainims = 'cars_train'
48 | base_folder_testims = 'cars_test'
49 |
50 | filename_testanno = 'cars_test_annos.mat'
51 | filename_trainanno = 'cars_train_annos.mat'
52 |
53 | base_folder = 'cars_train'
54 | train_list = [
55 | ['00001.jpg', '8df595812fee3ca9a215e1ad4b0fb0c4'],
56 | ['00002.jpg', '4b9e5efcc3612378ec63a22f618b5028']
57 | ]
58 | test_list = []
59 | num_training_classes = 98 # 196/2
60 |
61 | def __init__(self, root, train=False, transform=None, target_transform=None, **kwargs):
62 | self.root = root
63 | self.transform = transform
64 |
65 | self.target_transform = target_transform
66 | self.loader = default_loader
67 | print('Dataset Cars196 is trained with resolution 224!')
68 |
69 | self.samples = []
70 | self.nb_classes = 196
71 |
72 | if train:
73 | labels = \
74 | sio.loadmat(os.path.join(self.root, self.base_folder_devkit, self.filename_trainanno))['annotations'][0]
75 | for item in labels:
76 | img_name = item[-1].tolist()[0]
77 | label = int(item[4]) - 1
78 | self.samples.append((os.path.join(self.root, self.base_folder_trainims, img_name), label))
79 | else:
80 | labels = \
81 | sio.loadmat(os.path.join(self.root, 'cars_test_annos_withlabels.mat'))['annotations'][0]
82 | for item in labels:
83 | img_name = item[-1].tolist()[0]
84 | label = int(item[-2]) - 1
85 | self.samples.append((os.path.join(self.root, self.base_folder_testims, img_name), label))
86 |
87 | class Pets(ImageFolder):
88 | def __init__(self, root, train=True, transform=None, target_transform=None, **kwargs):
89 | self.dataset_root = root
90 | self.loader = default_loader
91 | self.target_transform = None
92 | self.transform = transform
93 | train_list_path = os.path.join(self.dataset_root, 'annotations', 'trainval.txt')
94 | test_list_path = os.path.join(self.dataset_root, 'annotations', 'test.txt')
95 |
96 | self.samples = []
97 | if train:
98 | with open(train_list_path, 'r') as f:
99 | for line in f:
100 | img_name = line.split(' ')[0]
101 | label = int(line.split(' ')[1])
102 | self.samples.append((os.path.join(root, 'images', "{}.jpg".format(img_name)), label-1))
103 | else:
104 | with open(test_list_path, 'r') as f:
105 | for line in f:
106 | img_name = line.split(' ')[0]
107 | label = int(line.split(' ')[1])
108 | self.samples.append((os.path.join(root, 'images', "{}.jpg".format(img_name)), label-1))
109 |
110 | class INatDataset(ImageFolder):
111 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
112 | category='name', loader=default_loader):
113 | self.transform = transform
114 | self.loader = loader
115 | self.target_transform = target_transform
116 | self.year = year
117 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
118 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
119 | with open(path_json) as json_file:
120 | data = json.load(json_file)
121 |
122 | with open(os.path.join(root, 'categories.json')) as json_file:
123 | data_catg = json.load(json_file)
124 |
125 | path_json_for_targeter = os.path.join(root, f"train{year}.json")
126 |
127 | with open(path_json_for_targeter) as json_file:
128 | data_for_targeter = json.load(json_file)
129 |
130 | targeter = {}
131 | indexer = 0
132 | for elem in data_for_targeter['annotations']:
133 | king = []
134 | king.append(data_catg[int(elem['category_id'])][category])
135 | if king[0] not in targeter.keys():
136 | targeter[king[0]] = indexer
137 | indexer += 1
138 | self.nb_classes = len(targeter)
139 |
140 | self.samples = []
141 | for elem in data['images']:
142 | cut = elem['file_name'].split('/')
143 | target_current = int(cut[2])
144 | path_current = os.path.join(root, cut[0], cut[2], cut[3])
145 |
146 | categors = data_catg[target_current]
147 | target_current_true = targeter[categors[category]]
148 | self.samples.append((path_current, target_current_true))
149 |
150 | # __getitem__ and __len__ inherited from ImageFolder
151 |
152 | def build_dataset(is_train, args, folder_name=None):
153 | transform = build_transform(is_train, args)
154 |
155 | if args.data_set == 'CIFAR10':
156 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform, download=True)
157 | nb_classes = 10
158 | elif args.data_set == 'CIFAR100':
159 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
160 | nb_classes = 100
161 | elif args.data_set == 'CARS':
162 | dataset = Cars196(args.data_path, train=is_train, transform=transform)
163 | nb_classes = 196
164 | elif args.data_set == 'PETS':
165 | dataset = Pets(args.data_path, train=is_train, transform=transform)
166 | nb_classes = 37
167 | elif args.data_set == 'FLOWERS':
168 | dataset = Flowers(args.data_path, train=is_train, transform=transform)
169 | nb_classes = 102
170 | elif args.data_set == 'IMNET':
171 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
172 | dataset = datasets.ImageFolder(root, transform=transform)
173 | nb_classes = 1000
174 | elif args.data_set == 'EVO_IMNET':
175 | root = os.path.join(args.data_path, folder_name)
176 | dataset = datasets.ImageFolder(root, transform=transform)
177 | nb_classes = 1000
178 | elif args.data_set == 'INAT':
179 | dataset = INatDataset(args.data_path, train=is_train, year=2018,
180 | category=args.inat_category, transform=transform)
181 | nb_classes = dataset.nb_classes
182 | elif args.data_set == 'INAT19':
183 | dataset = INatDataset(args.data_path, train=is_train, year=2019,
184 | category=args.inat_category, transform=transform)
185 |         nb_classes = dataset.nb_classes
186 |     else:  # fail loudly instead of returning an undefined nb_classes
187 |         raise ValueError(f'unknown dataset: {args.data_set}')
188 |     return dataset, nb_classes
189 | def build_transform(is_train, args):
190 | resize_im = args.input_size > 32
191 | if is_train:
192 | # this should always dispatch to transforms_imagenet_train
193 | transform = create_transform(
194 | input_size=args.input_size,
195 | is_training=True,
196 | color_jitter=args.color_jitter,
197 | auto_augment=args.aa,
198 | interpolation=args.train_interpolation,
199 | re_prob=args.reprob,
200 | re_mode=args.remode,
201 | re_count=args.recount,
202 | )
203 | if not resize_im:
204 | # replace RandomResizedCropAndInterpolation with
205 | # RandomCrop
206 | transform.transforms[0] = transforms.RandomCrop(
207 | args.input_size, padding=4)
208 | return transform
209 |
210 | t = []
211 | if resize_im:
212 | size = int((256 / 224) * args.input_size)
213 | t.append(
214 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
215 | )
216 | #t.append(transforms.CenterCrop(args.input_size))
217 | crop_tf = {
218 | 1: transforms.CenterCrop,
219 | 5: transforms.FiveCrop,
220 | 10: transforms.TenCrop,
221 | }
222 | t.append(crop_tf[args.eval_crops](args.input_size))
223 |
224 | if resize_im and args.eval_crops > 1:
225 | t.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])))
226 | else:
227 | t.append(transforms.ToTensor())
228 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
229 | return transforms.Compose(t)
230 |
--------------------------------------------------------------------------------
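
A sketch of driving build_dataset with an argparse-style namespace; the field values below are illustrative defaults (the real ones come from the training scripts), and CIFAR10 is chosen because it downloads automatically:

from types import SimpleNamespace
from lib.datasets import build_dataset

args = SimpleNamespace(
    data_set='CIFAR10', data_path='./data', input_size=224,
    color_jitter=0.4, aa='rand-m9-mstd0.5-inc1',
    train_interpolation='bicubic', reprob=0.25, remode='pixel',
    recount=1, eval_crops=1,
)
dataset, nb_classes = build_dataset(is_train=True, args=args)
print(len(dataset), nb_classes)  # 50000 10
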
/lib/imagenet_withhold.py:
--------------------------------------------------------------------------------
1 |
2 | from PIL import Image
3 | import io
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class ImageNet_Withhold(Dataset):
11 |     def __init__(self, data_root, ann_file='', transform=None, train=True, task='train'):
12 | super(ImageNet_Withhold, self).__init__()
13 | ann_file = ann_file + '/' + 'val_true.txt'
14 | train_split = (task == 'train' or task == 'val')
15 | self.data_root = data_root + '/'+ ('train' if train_split else 'val')
16 |
17 | self.data = []
18 | self.nb_classes = 0
19 | folders = {}
20 | cnt = 0
21 |         self.z = ZipReader()  # assumed provided elsewhere (not defined in this snapshot); reads image bytes out of per-class zip archives
22 | # if train:
23 | # for member in self.tarfile.getmembers():
24 | # print(member)
25 | # self.tarfile = tarfile.open(self.data_root)
26 |
27 | f = open(ann_file)
28 | prefix = 'data/sdb/imagenet'+'/'+ ('train' if train_split else 'val') + '/'
29 | for line in f:
30 | tmp = line.strip().split('\t')[0]
31 | class_pic = tmp.split('/')
32 | class_tmp = class_pic[0]
33 | pic = class_pic[1]
34 |
35 | if class_tmp in folders:
36 | # print(self.tarfile.getmember(('train/' if train else 'val/') + tmp[0] + '.JPEG'))
37 | self.data.append((class_tmp + '.zip', prefix + tmp + '.JPEG', folders[class_tmp]))
38 | else:
39 | folders[class_tmp] = cnt
40 | cnt += 1
41 | self.data.append((class_tmp + '.zip', prefix + tmp + '.JPEG',folders[class_tmp]))
42 |
43 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
44 | std=[0.229, 0.224, 0.225])
45 | if transform is not None:
46 | self.transforms = transform
47 | else:
48 | if train:
49 | self.transforms = transforms.Compose([
50 |                 transforms.RandomResizedCrop(224),  # RandomSizedCrop was removed from torchvision
51 | transforms.RandomHorizontalFlip(),
52 | transforms.ToTensor(),
53 | normalize,
54 | ])
55 | else:
56 | self.transforms = transforms.Compose([
57 |                 transforms.Resize(256),  # Scale was removed from torchvision
58 | transforms.CenterCrop(224),
59 | transforms.ToTensor(),
60 | normalize,
61 | ])
62 |
63 |         self.nb_classes = cnt
64 |
65 |     def __len__(self):
66 | return len(self.data)
67 |
68 | def __getitem__(self, idx):
69 |
70 | # print('extract_file', time.time()-start_time)
71 | iob = self.z.read(self.data_root + '/' + self.data[idx][0], self.data[idx][1])
72 | iob = io.BytesIO(iob)
73 | img = Image.open(iob).convert('RGB')
74 | target = self.data[idx][2]
75 | if self.transforms is not None:
76 | img = self.transforms(img)
77 | # print('open', time.time()-start_time)
78 | return img, target
79 |
--------------------------------------------------------------------------------
/lib/samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import math
4 |
5 |
6 | class RASampler(torch.utils.data.Sampler):
7 | """Sampler that restricts data loading to a subset of the dataset for distributed,
8 | with repeated augmentation.
9 |     It ensures that each augmented version of a sample is visible to a
10 |     different process (GPU).
11 | Heavily based on torch.utils.data.DistributedSampler
12 | """
13 |
14 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
15 | if num_replicas is None:
16 | if not dist.is_available():
17 | raise RuntimeError("Requires distributed package to be available")
18 | num_replicas = dist.get_world_size()
19 | if rank is None:
20 | if not dist.is_available():
21 | raise RuntimeError("Requires distributed package to be available")
22 | rank = dist.get_rank()
23 | self.dataset = dataset
24 | self.num_replicas = num_replicas
25 | self.rank = rank
26 | self.epoch = 0
27 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
28 | self.total_size = self.num_samples * self.num_replicas
29 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
30 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
31 | self.shuffle = shuffle
32 |
33 | def __iter__(self):
34 | # deterministically shuffle based on epoch
35 | g = torch.Generator()
36 | g.manual_seed(self.epoch)
37 | if self.shuffle:
38 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
39 | else:
40 | indices = list(range(len(self.dataset)))
41 |
42 |         # repeat each sample 3x (repeated augmentation), then pad to make it evenly divisible
43 |         indices = [ele for ele in indices for i in range(3)]
44 |         indices += indices[:(self.total_size - len(indices))]
45 | assert len(indices) == self.total_size
46 |
47 | # subsample
48 | indices = indices[self.rank:self.total_size:self.num_replicas]
49 |
50 | assert len(indices) == self.num_samples
51 | return iter(indices[:self.num_selected_samples])
52 |
53 | def __len__(self):
54 | return self.num_selected_samples
55 |
56 | def set_epoch(self, epoch):
57 | self.epoch = epoch
58 |
--------------------------------------------------------------------------------
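
A self-contained sketch of RASampler; num_replicas and rank are passed explicitly so it runs without an initialized process group (in real training they default to the distributed world size and rank):

import torch
from lib.samplers import RASampler

dataset = torch.utils.data.TensorDataset(torch.randn(1024, 3))
sampler = RASampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=64)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseeds the shuffle consistently across replicas
    for (batch,) in loader:
        pass  # training step goes here
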
/lib/score_maker.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import math
4 | import logging
5 | from torch import optim
6 | import torch.nn.functional as F
7 | import torch.distributed as dist
8 | from contextlib import suppress
9 | from scipy import stats
10 | import numpy as np
11 | from sklearn.metrics import roc_auc_score
12 | from tqdm import tqdm
13 | import random
14 | import functools
15 |
16 | from typing import Iterable, Optional
17 | from timm.data import Mixup
18 | from timm.optim import create_optimizer
19 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20 |
21 | from model.supernet_transformer import Vision_TransformerSuper
22 |
23 |
24 | def is_master(args, local=False):
25 |     return is_local_master(args) if local else is_global_master(args)  # NOTE: these helpers are not defined in this snapshot
26 |
27 |
28 | def unwrap_model(model):
29 | if hasattr(model, 'module'):
30 | return model.module
31 | else:
32 | return model
33 |
34 |
35 | def gen_key(embed_dim, choice):
36 | return f'{embed_dim},{choice}'
37 |
38 |
39 | class ScoreMaker(object):
40 | def __init__(self):
41 | self.grad_dict_before_train = {}
42 | self.grad_dict_after_train = {}
43 | self.param_val_dict = {}
44 | self.item_score_dict = {}
45 | self.key_items = ['attn.qkv.weight', 'fc1.weight']
46 |
47 | def drop_gradient(self):
48 | self.grad_dict = None
49 |
50 | def nan_to_zero(self, a):
51 | return torch.where(torch.isnan(a), torch.full_like(a, 0), a)
52 |
53 | def build_avg_image(self, s):
54 | # 3 channel in image
55 | assert s[1] == 3
56 |
57 | img = torch.zeros(s)
58 | mean = IMAGENET_DEFAULT_MEAN
59 | std = IMAGENET_DEFAULT_STD
60 | for i in range(3):
61 | torch.nn.init.normal_(img[:, i, :, :], mean=mean[i], std=std[i])
62 | return img
63 |
64 | def get_gradient(self, model, criterion, data_loader, args, choices, device, mixup_fn: Optional[Mixup] = None):
65 | config = {}
66 | dimensions = ['mlp_ratio', 'num_heads']
67 | depth = max(choices['depth'])
68 | for dimension in dimensions:
69 | config[dimension] = [max(choices[dimension]) for _ in range(depth)]
70 | config['embed_dim'] = [max(choices['embed_dim'])] * depth
71 | config['layer_num'] = depth
72 |
73 | model_module = unwrap_model(model)
74 | model_module.set_sample_config(config=config)
75 |
76 | model.train()
77 | criterion.train()
78 |
79 | random.seed(0)
80 |
81 | optimizer = create_optimizer(args, model_module)
82 |
83 | batch_num = 0
84 | grad_dict = {}
85 |
86 | optimizer.zero_grad()
87 |
88 | for samples, targets in data_loader:
89 | samples = samples.to(device, non_blocking=True)
90 | targets = targets.to(device, non_blocking=True)
91 |
92 | if mixup_fn is not None:
93 | samples, targets = mixup_fn(samples, targets)
94 |
95 | if args.data_free:
96 | input_dim = list(samples[0, :].shape)
97 | inputs = self.build_avg_image([64] + input_dim).to(device) # 64 batch image
98 | output = model.forward(inputs)
99 | torch.sum(output).backward()
100 | batch_num += 1
101 | print('data free!')
102 | break
103 |
104 | outputs = model(samples)
105 | loss = criterion(outputs, targets)
106 | loss.backward()
107 | batch_num += 1
108 |
109 | for k, param in model_module.named_parameters():
110 | if param.requires_grad:
111 | grad_dict[k] = param.grad
112 |
113 | if args.distributed:
114 | dist.barrier()
115 | result = torch.tensor([batch_num]).to(args.device, non_blocking=True)
116 | dist.all_reduce(result)
117 | batch_num = result[0]
118 | for k, v in grad_dict.items():
119 | dist.all_reduce(grad_dict[k])
120 |
121 | for k, v in grad_dict.items():
122 | grad_dict[k] = v / batch_num
123 | # grad_dict[k] = grad_dict[k].cpu()
124 |
125 | self.grad_dict = grad_dict
126 |
127 | def get_block_scores(self, model, args, choices):
128 | model_module = unwrap_model(model)
129 | param_val_dict = model_module.state_dict()
130 |
131 | head_choices = choices['num_heads']
132 | mlp_choices = choices['mlp_ratio']
133 | layers = max(choices['depth'])
134 | embed_dims = choices['embed_dim']
135 | max_dim = max(choices['embed_dim'])
136 | head_dim = model_module.super_embed_dim // model_module.super_num_heads
137 |
138 | head_score_dict = {}
139 | mlp_score_dict = {}
140 | for embed_dim in embed_dims:
141 | for j in range(len(head_choices) - 1):
142 | head_score_dict[gen_key(embed_dim, head_choices[j + 1])] = []
143 | for j in range(len(mlp_choices) - 1):
144 | mlp_score_dict[gen_key(embed_dim, mlp_choices[j + 1])] = []
145 |
146 | def get_item_score(pv, gv, block_score_method):
147 | pg = pv.mul(gv)
148 |
149 | if 'balance_taylor6_norm' in block_score_method:
150 | item_score = pg.abs() / pv.abs().sum() / gv.abs().sum() / pg.abs().sum()
151 | elif 'taylor6_doublenorm' in block_score_method:
152 | item_score = pg.abs() / pg.abs().sum() / pg.abs().sum()
153 | elif 'taylor6_norm' in block_score_method:
154 | item_score = pg.abs() / pg.abs().sum()
155 | elif 'balance_taylor6' in block_score_method:
156 | item_score = pg.abs() / pv.abs().sum() / gv.abs().sum()
157 | elif 'taylor6' in block_score_method:
158 | item_score = pg.abs()
159 | elif 'balance_taylor5_norm' in block_score_method:
160 | item_score = pg / pv.sum().abs() / gv.sum().abs() / pg.sum().abs()
161 | elif 'taylor5_doublenorm' in block_score_method:
162 | item_score = pg / pg.sum().abs() / pg.sum().abs()
163 | elif 'taylor5_norm' in block_score_method:
164 | item_score = pg / pg.sum().abs()
165 | elif 'balance_taylor5' in block_score_method:
166 | item_score = pg / pv.sum().abs() / gv.sum().abs()
167 | elif 'taylor5' in block_score_method:
168 | item_score = pg
169 | elif 'taylor9_norm' in block_score_method:
170 | item_score = gv.abs() / gv.abs().sum()
171 | elif 'taylor9_doublenorm' in block_score_method:
172 | item_score = gv.abs() / gv.abs().sum() / gv.abs().sum()
173 | elif 'taylor9' in block_score_method:
174 | item_score = gv.abs()
175 | elif 'l1norm' in block_score_method:
176 | item_score = pv.abs()
177 | else:
178 | item_score = pv
179 | return item_score
180 |
181 | for embed_dim in embed_dims:
182 | for i in range(layers):
183 | qkv_w = f'blocks.{i}.attn.qkv.weight'
184 | c_fc_w = f'blocks.{i}.fc1.weight'
185 | c_proj_w = f'blocks.{i}.fc2.weight'
186 | qkv_score = get_item_score(param_val_dict[qkv_w][:, :max_dim],
187 | self.grad_dict[qkv_w][:, :max_dim],
188 | args.block_score_method_for_head)[:, :embed_dim]
189 | c_fc_score = get_item_score(param_val_dict[c_fc_w][:, :max_dim],
190 | self.grad_dict[c_fc_w][:, :max_dim],
191 | args.block_score_method_for_mlp)[:, :embed_dim]
192 | c_proj_score = get_item_score(param_val_dict[c_proj_w][:max_dim, :],
193 | self.grad_dict[c_proj_w][:max_dim, :],
194 | args.block_score_method_for_mlp)[:embed_dim, :]
195 | for j in range(len(head_choices) - 1):
196 | qkv_embed_base = head_dim * head_choices[j]
197 | qkv_embed_dim = head_dim * head_choices[j + 1]
198 | left_qkv_score = torch.cat([qkv_score[qkv_embed_base * 3 + k:qkv_embed_dim * 3:3, :] for k in range(3)], dim=0)
199 | score = left_qkv_score.sum().abs().cpu()
200 | head_score_dict[gen_key(embed_dim, head_choices[j + 1])].append(score)
201 | for j in range(len(mlp_choices) - 1):
202 | mlp_embed_base = int(embed_dim * mlp_choices[j])
203 | mlp_embed_dim = int(embed_dim * mlp_choices[j + 1])
204 | left_c_fc_score = c_fc_score[mlp_embed_base:mlp_embed_dim, :]
205 | left_c_proj_score = c_proj_score[:, mlp_embed_base:mlp_embed_dim]
206 | score = left_c_fc_score.sum().abs().cpu() + left_c_proj_score.sum().abs().cpu()
207 | mlp_score_dict[gen_key(embed_dim, mlp_choices[j + 1])].append(score)
208 |
209 | return {'head_scores': head_score_dict, 'mlp_scores': mlp_score_dict}
210 |
211 |
212 | def get_item_score(self, model, criterion, data_loader, args, choices, device, mixup_fn: Optional[Mixup] = None):
213 |
214 | config = {}
215 | dimensions = ['mlp_ratio', 'num_heads']
216 | depth = max(choices['depth'])
217 | for dimension in dimensions:
218 | config[dimension] = [max(choices[dimension]) for _ in range(depth)]
219 | config['embed_dim'] = [max(choices['embed_dim'])] * depth
220 | config['layer_num'] = depth
221 |
222 | model_module = unwrap_model(model)
223 | model_module.set_sample_config(config=config)
224 |
225 | param_val_dict = model_module.state_dict()
226 | grad_dict = {}
227 |
228 | if 'taylor' in args.score_method:
229 | model.train()
230 | criterion.train()
231 |
232 | random.seed(0)
233 |
234 | optimizer = create_optimizer(args, model_module)
235 |
236 | batch_num = 0
237 |
238 | for samples, targets in data_loader:
239 | samples = samples.to(device, non_blocking=True)
240 | targets = targets.to(device, non_blocking=True)
241 |
242 | if mixup_fn is not None:
243 | samples, targets = mixup_fn(samples, targets)
244 |
245 | outputs = model(samples)
246 | loss = criterion(outputs, targets)
247 |
248 | optimizer.zero_grad()
249 | loss.backward()
250 |
251 | for k, param in model_module.named_parameters():
252 | if param.requires_grad:
253 | if batch_num == 0:
254 | grad_dict[k] = copy.deepcopy(param.grad)
255 | else:
256 | grad_dict[k] = grad_dict[k] + param.grad
257 |
258 | batch_num += 1
259 |
260 | if args.distributed:
261 | dist.barrier()
262 | result = torch.tensor([batch_num]).to(args.device, non_blocking=True)
263 | dist.all_reduce(result)
264 | batch_num = result[0]
265 | for k, v in grad_dict.items():
266 | dist.all_reduce(grad_dict[k])
267 |
268 | for k, v in grad_dict.items():
269 | grad_dict[k] = v / batch_num
270 |             # grad_dict[k] = grad_dict[k].cpu()
271 |
272 | for k in param_val_dict.keys():
273 | for key_item in self.key_items:
274 | if key_item in k:
275 | if 'l1norm' in args.score_method:
276 | self.item_score_dict[k] = param_val_dict[k].abs().cpu()
277 | elif 'taylor5' in args.score_method:
278 | self.item_score_dict[k] = param_val_dict[k].mul(grad_dict[k]).cpu()
279 | elif 'taylor6' in args.score_method:
280 | self.item_score_dict[k] = param_val_dict[k].mul(grad_dict[k]).abs().cpu()
281 | elif 'taylor9' in args.score_method:
282 | self.item_score_dict[k] = grad_dict[k].abs().cpu()
283 | else:
284 | assert False
285 |
286 | def get_head_score(self, head_choices, layers, head_dim, embed_dims, layer_norm=False):
287 | score_dict = {}
288 | for embed_dim in embed_dims:
289 | for j in range(len(head_choices) - 1):
290 | score_dict[gen_key(embed_dim, head_choices[j + 1])] = []
291 | for embed_dim in embed_dims:
292 | for i in range(layers):
293 | qkv_w = f'blocks.{i}.attn.qkv.weight'
294 | for j in range(len(head_choices) - 1):
295 | qkv_embed_base = head_dim * head_choices[j]
296 | qkv_embed_dim = head_dim * head_choices[j + 1]
297 | qkv_score = self.item_score_dict[qkv_w][:,:embed_dim]
298 | item_score = torch.cat([qkv_score[qkv_embed_base * 3 + k:qkv_embed_dim * 3:3, :] for k in range(3)], dim=0)
299 | score = item_score.sum().abs()
300 | if layer_norm:
301 | score = score / qkv_score.sum().abs()
302 | score_dict[gen_key(embed_dim, head_choices[j + 1])].append(score)
303 | return score_dict
304 |
305 | def get_mlp_score(self, mlp_choices, layers, embed_dims, layer_norm=False):
306 | score_dict = {}
307 | for embed_dim in embed_dims:
308 | for j in range(len(mlp_choices) - 1):
309 | score_dict[gen_key(embed_dim, mlp_choices[j + 1])] = []
310 | for embed_dim in embed_dims:
311 | for i in range(layers):
312 | c_fc_w = f'blocks.{i}.fc1.weight'
313 | for j in range(len(mlp_choices) - 1):
314 | mlp_embed_base = int(embed_dim * mlp_choices[j])
315 | mlp_embed_dim = int(embed_dim * mlp_choices[j+1])
316 | mlp_score = self.item_score_dict[c_fc_w][:, :embed_dim]
317 | item_score = mlp_score[mlp_embed_base:mlp_embed_dim, :]
318 | score = item_score.sum().abs()
319 | if layer_norm:
320 | score = score / mlp_score.sum().abs()
321 | score_dict[gen_key(embed_dim, mlp_choices[j + 1])].append(score)
322 | return score_dict
323 |
324 | def get_left_part_from_super_model(self, model: Vision_TransformerSuper, para_dict, sample_config):
325 | layers = model.super_layer_num
326 | sample_layers = sample_config['layer_num']
327 | left_dict = {}
328 |
329 | embed_dims = sample_config['embed_dim']
330 | output_dims = [out_dim for out_dim in sample_config['embed_dim'][1:]] + [sample_config['embed_dim'][-1]]
331 |
332 | left_dict['patch_embed_super.proj.weight'] = para_dict['patch_embed_super.proj.weight'][:embed_dims[0], ...]
333 | left_dict['patch_embed_super.proj.bias'] = para_dict['patch_embed_super.proj.bias'][:embed_dims[0], ...]
334 | left_dict['norm.weight'] = para_dict['norm.weight'][:embed_dims[-1]]
335 | left_dict['norm.bias'] = para_dict['norm.bias'][:embed_dims[-1]]
336 | left_dict['head.weight'] = para_dict['head.weight'][:, :embed_dims[-1]]
337 | left_dict['head.bias'] = para_dict['head.bias'][:embed_dims[-1]]
338 |
339 | for i in range(layers):
340 | qkv_w = f'blocks.{i}.attn.qkv.weight'
341 | qkv_b = f'blocks.{i}.attn.qkv.bias'
342 | proj_w = f'blocks.{i}.attn.proj.weight'
343 | proj_b = f'blocks.{i}.attn.proj.bias'
344 | ln1_w = f'blocks.{i}.attn_layer_norm.weight'
345 | ln1_b = f'blocks.{i}.attn_layer_norm.bias'
346 | c_fc_w = f'blocks.{i}.fc1.weight'
347 | c_fc_b = f'blocks.{i}.fc1.bias'
348 | c_proj_w = f'blocks.{i}.fc2.weight'
349 | c_proj_b = f'blocks.{i}.fc2.bias'
350 | ln2_w = f'blocks.{i}.ffn_layer_norm.weight'
351 | ln2_b = f'blocks.{i}.ffn_layer_norm.bias'
352 | if i < sample_layers:
353 | num_heads = sample_config['num_heads'][i]
354 | head_dim = model.super_embed_dim // model.super_num_heads
355 | qk_embed_dim = head_dim * num_heads
356 | mlp_ratio = sample_config['mlp_ratio'][i]
357 | embed_dim = embed_dims[i]
358 | mlp_width = int(embed_dim * mlp_ratio)
359 | output_dim = output_dims[i]
360 |
361 | left_dict[qkv_w] = para_dict[qkv_w][:, :embed_dim]
362 | left_dict[qkv_w] = torch.cat([left_dict[qkv_w][i:qk_embed_dim*3:3, :] for i in range(3)], dim=0)
363 |
364 | # left_dict[qkv_b] = para_dict[qkv_b][:qk_embed_dim*3]
365 | left_dict[qkv_b] = torch.cat([para_dict[qkv_b][i:qk_embed_dim*3:3] for i in range(3)])
366 |
367 | left_dict[proj_w] = para_dict[proj_w][:, :qk_embed_dim]
368 | left_dict[proj_w] = left_dict[proj_w][:embed_dim, :]
369 |
370 | left_dict[proj_b] = para_dict[proj_b][:embed_dim]
371 |
372 | left_dict[ln1_w] = para_dict[ln1_w][:embed_dim]
373 |
374 | left_dict[ln1_b] = para_dict[ln1_b][:embed_dim]
375 |
376 | left_dict[c_fc_w] = para_dict[c_fc_w][:, :embed_dim]
377 | left_dict[c_fc_w] = left_dict[c_fc_w][:mlp_width, :]
378 |
379 | left_dict[c_fc_b] = para_dict[c_fc_b][:mlp_width]
380 |
381 | left_dict[c_proj_w] = para_dict[c_proj_w][:, :mlp_width]
382 | left_dict[c_proj_w] = left_dict[c_proj_w][:output_dim, :]
383 |
384 | left_dict[c_proj_b] = para_dict[c_proj_b][:output_dim]
385 |
386 | left_dict[ln2_w] = para_dict[ln2_w][:output_dim]
387 |
388 | left_dict[ln2_b] = para_dict[ln2_b][:output_dim]
389 | else:
390 | continue
391 |
392 | num_paras = 0
393 | for k, v in left_dict.items():
394 | num_paras += v.numel()
395 |
396 | return left_dict, num_paras
397 |
398 | def get_left_attn_mlp_from_super_model(self, model: Vision_TransformerSuper, para_dict, sample_config):
399 |
400 | layers = model.super_layer_num
401 | sample_layers = sample_config['layer_num']
402 | left_dict = {}
403 |
404 | embed_dims = sample_config['embed_dim']
405 | output_dims = [out_dim for out_dim in sample_config['embed_dim'][1:]] + [sample_config['embed_dim'][-1]]
406 |
407 | for i in range(layers):
408 | qkv_w = f'blocks.{i}.attn.qkv.weight'
409 | qkv_b = f'blocks.{i}.attn.qkv.bias'
410 | proj_w = f'blocks.{i}.attn.proj.weight'
411 | proj_b = f'blocks.{i}.attn.proj.bias'
412 | c_fc_w = f'blocks.{i}.fc1.weight'
413 | c_fc_b = f'blocks.{i}.fc1.bias'
414 | c_proj_w = f'blocks.{i}.fc2.weight'
415 | c_proj_b = f'blocks.{i}.fc2.bias'
416 | if i < sample_layers:
417 | num_heads = sample_config['num_heads'][i]
418 | head_dim = model.super_embed_dim // model.super_num_heads
419 | qk_embed_dim = head_dim * num_heads
420 | mlp_ratio = sample_config['mlp_ratio'][i]
421 | embed_dim = embed_dims[i]
422 | mlp_width = int(embed_dim * mlp_ratio)
423 | output_dim = output_dims[i]
424 |
425 | left_dict[qkv_w] = para_dict[qkv_w][:, :embed_dim]
426 | left_dict[qkv_w] = torch.cat([left_dict[qkv_w][i:qk_embed_dim * 3:3, :] for i in range(3)], dim=0)
427 |
428 | # left_dict[qkv_b] = para_dict[qkv_b][:qk_embed_dim*3]
429 | left_dict[qkv_b] = torch.cat([para_dict[qkv_b][i:qk_embed_dim * 3:3] for i in range(3)], dim=0)
430 |
431 | left_dict[proj_w] = para_dict[proj_w][:, :qk_embed_dim]
432 | left_dict[proj_w] = left_dict[proj_w][:embed_dim, :]
433 |
434 | left_dict[proj_b] = para_dict[proj_b][:embed_dim]
435 |
436 | left_dict[c_fc_w] = para_dict[c_fc_w][:, :embed_dim]
437 | left_dict[c_fc_w] = left_dict[c_fc_w][:mlp_width, :]
438 |
439 | left_dict[c_fc_b] = para_dict[c_fc_b][:mlp_width]
440 |
441 | left_dict[c_proj_w] = para_dict[c_proj_w][:, :mlp_width]
442 | left_dict[c_proj_w] = left_dict[c_proj_w][:output_dim, :]
443 |
444 | left_dict[c_proj_b] = para_dict[c_proj_b][:output_dim]
445 | else:
446 | continue
447 |
448 | num_paras = 0
449 | for k, v in left_dict.items():
450 | num_paras += v.numel()
451 |
452 | return left_dict, num_paras
453 |
454 | def get_scores(self, model, score_methods, config):
455 | score_methods = score_methods.strip().split('+')
456 | scores = []
457 | for score_method in score_methods:
458 | scores.append(self.get_score(model, score_method, config))
459 | return scores
460 |
461 | def get_score(self, model, score_method, config):
462 |
463 | if score_method == 'entropy':
464 | depth, mlp_ratio, num_heads, embed_dim = config['layer_num'], config['mlp_ratio'], config['num_heads'], config['embed_dim']
465 | entropy_score = 0.
466 | for i in range(depth):
467 | d = embed_dim[i]
468 | n = 14 * 14 # input_size = 224, patch_size = 16
469 | d_f = mlp_ratio[i] * d
470 | d_h = 64
471 | n_h = num_heads[i]
472 | entropy_score += math.log(d_f) + math.log(d_h * n_h) + math.log(n) + 4 * math.log(d)
473 | return entropy_score
474 |
475 | super_paras = unwrap_model(model).state_dict()
476 |
477 | if 'left_attn_mlp' in score_method:
478 | para, num_paras = self.get_left_attn_mlp_from_super_model(unwrap_model(model), super_paras, config)
479 | else:
480 | para, num_paras = self.get_left_part_from_super_model(unwrap_model(model), super_paras, config)
481 | grad = None
482 | if 'taylor' in score_method:
483 | if 'left_attn_mlp' in score_method:
484 | grad, _ = self.get_left_attn_mlp_from_super_model(unwrap_model(model), self.grad_dict, config)
485 | else:
486 | grad, _ = self.get_left_part_from_super_model(unwrap_model(model), self.grad_dict, config)
487 |
488 | if 'avg' not in score_method:
489 | num_paras = None
490 |
491 |         if 'l1norm_norm' in score_method:  # '*_norm' variants must be tested before their plain prefixes
492 |             res = self.criterion_l_l1norm(para, super_paras=super_paras)
493 |         elif 'l1norm' in score_method:
494 |             res = self.criterion_l_l1norm(para, num_paras)
495 |         elif 'taylor5_norm' in score_method:
496 |             res = self.criterion_l_taylor5(para, grad, super_paras=super_paras, super_grads=self.grad_dict)
497 |         elif 'taylor5' in score_method:
498 |             res = self.criterion_l_taylor5(para, grad, num_paras)
499 |         elif 'taylor6_norm' in score_method:
500 |             res = self.criterion_l_taylor6(para, grad, super_paras=super_paras, super_grads=self.grad_dict)
501 |         elif 'taylor6' in score_method:
502 |             res = self.criterion_l_taylor6(para, grad, num_paras)
503 |         elif 'taylor9_norm' in score_method:
504 |             res = self.criterion_l_taylor9(para, grad, super_paras=super_paras, super_grads=self.grad_dict)
505 |         elif 'taylor9' in score_method:
506 |             res = self.criterion_l_taylor9(para, grad, num_paras)
507 |         else:
508 |             assert False, f'unknown score method: {score_method}'
509 |
510 | if 'pruned' in score_method:
511 | res = - res
512 | if type(res) == float:
513 | return res
514 | else:
515 | return res.cpu()
516 |
517 | def criterion_l_l1norm(self, paras, num_paras=None, super_paras=None):
518 | score = 0.
519 | for k, v in paras.items():
520 | if super_paras is not None:
521 | score += v.abs().sum() / super_paras[k].abs().sum()
522 | else:
523 | score += v.abs().sum()
524 | if num_paras:
525 | score /= num_paras
526 | return score.cpu()
527 |
528 | def criterion_l_l2norm(self, paras, num_paras=None, super_paras=None):
529 | score = 0.
530 | for k, v in paras.items():
531 | if super_paras is not None:
532 | score += v.norm() / super_paras[k].norm()
533 | else:
534 | score += v.norm()
535 | if num_paras:
536 | score /= num_paras
537 | return score.cpu()
538 |
539 | def criterion_l_taylor1(self, paras, grads, num_paras=None, super_paras=None, super_grads=None):
540 | score = 0.
541 | for k, v in paras.items():
542 | g = grads[k]
543 | if super_paras is not None and super_grads is not None:
544 | score += v.mul(g).sum() / super_paras[k].mul(super_grads[k]).sum()
545 | else:
546 | score += v.mul(g).sum()
547 | if num_paras:
548 | score /= num_paras
549 | score = score ** 2
550 | return score.cpu()
551 |
552 | def criterion_l_taylor2(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # fisher
553 | score = 0.
554 | for k, v in paras.items():
555 | g = grads[k]
556 | if super_paras is not None and super_grads is not None:
557 | score += (v.mul(g) ** 2).sum() / (super_paras[k].mul(super_grads[k]) ** 2).sum()
558 | else:
559 | score += (v.mul(g) ** 2).sum()
560 | if num_paras:
561 | score /= num_paras
562 | return score.cpu()
563 |
564 | def criterion_l_taylor3(self, paras, grads, num_paras=None, super_paras=None, super_grads=None):
565 | score = 0.
566 | for k, v in paras.items():
567 | g = grads[k]
568 | if super_paras is not None and super_grads is not None:
569 | score += g.sum() / super_grads[k].sum()
570 | else:
571 | score += g.sum()
572 | if num_paras:
573 | score /= num_paras
574 | score = score ** 2
575 | return score.cpu()
576 |
577 | def criterion_l_taylor4(self, paras, grads, num_paras=None, super_paras=None, super_grads=None):
578 | score = 0.
579 | for k, v in paras.items():
580 | g = grads[k]
581 | if super_paras is not None and super_grads is not None:
582 | score += (g ** 2).sum() / (super_grads[k] ** 2).sum()
583 | else:
584 | score += (g ** 2).sum()
585 | if num_paras:
586 | score /= num_paras
587 | return score.cpu()
588 |
589 | def criterion_l_taylor5(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # synflow
590 | score = 0.
591 | for k, v in paras.items():
592 | g = grads[k]
593 | if super_paras is not None and super_grads is not None:
594 | score += v.mul(g).sum() / super_paras[k].mul(super_grads[k]).sum()
595 | else:
596 | score += v.mul(g).sum()
597 | score = score.abs()
598 | if num_paras:
599 | score /= num_paras
600 | return score.cpu()
601 |
602 | def criterion_l_taylor6(self, paras, grads, num_paras=None, super_paras=None, super_grads=None):
603 | score = 0.
604 | for k, v in paras.items():
605 | g = grads[k]
606 | if super_paras is not None and super_grads is not None:
607 | score += v.mul(g).abs().sum() / super_paras[k].mul(super_grads[k]).abs().sum()
608 | else:
609 | score += v.mul(g).abs().sum()
610 | if num_paras:
611 | score /= num_paras
612 | return score.cpu()
613 |
614 | def criterion_l_taylor7(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # Euclidean norm of the gradients v1
615 | score = 0.
616 | for k, v in paras.items():
617 | g = grads[k]
618 | if super_paras is not None and super_grads is not None:
619 | score += (g ** 2).sum() / (super_grads[k] ** 2).sum()
620 | else:
621 | score += (g ** 2).sum()
622 | if num_paras:
623 | score /= num_paras
624 | score = math.sqrt(score)
625 | return score
626 |
627 | def criterion_l_taylor8(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # Euclidean norm of the gradients v2
628 | score = 0.
629 | for k, v in paras.items():
630 | g = grads[k]
631 | if super_paras is not None and super_grads is not None:
632 | score += g.norm() / super_grads[k].norm()
633 | else:
634 | score += g.norm()
635 | if num_paras:
636 | score /= num_paras
637 | return score.cpu()
638 |
639 | def criterion_l_taylor9(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # snip
640 | score = 0.
641 | for k, v in paras.items():
642 | g = grads[k]
643 | if super_paras is not None and super_grads is not None:
644 | score += g.abs().sum() / super_grads[k].abs().sum()
645 | else:
646 | score += g.abs().sum()
647 | if num_paras:
648 | score /= num_paras
649 | score = math.sqrt(score)
650 | return score
651 |
652 | def sort_eval(self, vis_dict, top_num=50):
653 | acc_list = []
654 | score_lists = {}
655 | for cand, info in vis_dict.items():
656 | # only consider the model under limits
657 | if 'acc' not in info.keys():
658 | continue
659 | acc_list.append(info['acc'])
660 | score_stats = info['score_stats']
661 | for k, score in score_stats.items():
662 | if k not in score_lists.keys():
663 | score_lists[k] = []
664 | if type(score) == float:
665 | score_lists[k].append(score)
666 | else:
667 | score_lists[k].append(score.cpu())
668 | acc_list = np.array(acc_list)
669 | for k in score_lists.keys():
670 | score_lists[k] = np.array(score_lists[k])
671 |
672 |         p_vals = self.get_p_value(acc_list, score_lists)  # NOTE: get_p_value is not defined in this snapshot; expected to yield (p_value, key) pairs
673 | p_dict = {}
674 | for (p_val, k) in p_vals:
675 | p_dict[k] = p_val
676 |
677 | idx = acc_list.argsort()[-top_num:][::-1]
678 | sorted_acc = [(acc_list[idx], 'acc_list')]
679 | for k, scores in score_lists.items():
680 | if p_dict[k] > 0:
681 | idx = scores.argsort()[-top_num:][::-1]
682 | else:
683 | idx = scores.argsort()[:top_num]
684 | acc = acc_list[idx]
685 | acc.sort()
686 | sorted_acc.append((acc[::-1], k))
687 |
688 | def compare(a, b):
689 | for i in range(a[0].size):
690 | if a[0][i] > b[0][i]:
691 | return 1
692 | elif a[0][i] < b[0][i]:
693 | return -1
694 | return 0
695 |
696 | sorted_acc.sort(key=functools.cmp_to_key(compare))
697 | for acc in sorted_acc:
698 | print(acc)
699 |
700 |
--------------------------------------------------------------------------------
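
As a toy sanity check of the taylor6 criterion above (made-up tensors; taylor6 sums |w * g| over the retained weights, a first-order Taylor saliency):

import torch
from lib.score_maker import ScoreMaker

sm = ScoreMaker()
paras = {'a': torch.tensor([1.0, -2.0]), 'b': torch.tensor([[0.5, 0.0]])}
grads = {'a': torch.tensor([0.1, 0.3]), 'b': torch.tensor([[2.0, 1.0]])}
# |1.0*0.1| + |-2.0*0.3| + |0.5*2.0| + |0.0*1.0| = 0.1 + 0.6 + 1.0 + 0.0
print(sm.criterion_l_taylor6(paras, grads))  # tensor(1.7000)
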
/lib/subImageNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | random.seed(0)
5 | parser = argparse.ArgumentParser('Generate SubImageNet', add_help=False)
6 | parser.add_argument('--data-path', default='../data/imagenet', type=str,
7 | help='dataset path')
8 | args = parser.parse_args()
9 |
10 | data_path = args.data_path
11 | ImageNet_train_path = os.path.join(data_path, 'train')
12 | subImageNet_name = 'subImageNet'
13 | class_idx_txt_path = os.path.join(data_path, subImageNet_name)
14 |
15 | # train
16 | classes = sorted(os.listdir(ImageNet_train_path))
17 | if not os.path.exists(os.path.join(data_path, subImageNet_name)):
18 | os.mkdir(os.path.join(data_path, subImageNet_name))
19 |
20 | subImageNet = dict()
21 | with open(os.path.join(class_idx_txt_path, 'subimages_list.txt'), 'w') as f:
22 | subImageNet_class = classes
23 | for iclass in subImageNet_class:
24 | class_path = os.path.join(ImageNet_train_path, iclass)
25 | if not os.path.exists(
26 | os.path.join(
27 | data_path,
28 | subImageNet_name,
29 | iclass)):
30 | os.mkdir(os.path.join(data_path, subImageNet_name, iclass))
31 | subImages = random.sample(sorted(os.listdir(class_path)), 100)
32 | # print("{}\n".format(subImages))
33 | f.write("{}\n".format(subImages))
34 | subImageNet[iclass] = subImages
35 | for image in subImages:
36 | raw_path = os.path.join(ImageNet_train_path, iclass, image)
37 | new_ipath = os.path.join(
38 | data_path, subImageNet_name, iclass, image)
39 | os.system('cp {} {}'.format(raw_path, new_ipath))
40 |
41 | sub_classes = sorted(subImageNet.keys())
42 | with open(os.path.join(class_idx_txt_path, 'info.txt'), 'w') as f:
43 | class_idx = 0
44 | for key in sub_classes:
45 | images = sorted((subImageNet[key]))
46 | # print(len(images))
47 | f.write("{}\n".format(key))
48 | class_idx = class_idx + 1
--------------------------------------------------------------------------------
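
This is a standalone script: running "python lib/subImageNet.py --data-path /path/to/imagenet" copies 100 randomly sampled (seed-fixed) training images per class into <data-path>/subImageNet/ and records the selection in subimages_list.txt and info.txt.
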
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import time
4 | from collections import defaultdict, deque
5 | import datetime
6 |
7 | import torch
8 | import torch.distributed as dist
9 |
10 |
11 | class SmoothedValue(object):
12 | """Track a series of values and provide access to smoothed values over a
13 | window or the global series average.
14 | """
15 |
16 | def __init__(self, window_size=20, fmt=None):
17 | if fmt is None:
18 | fmt = "{median:.4f} ({global_avg:.4f})"
19 | self.deque = deque(maxlen=window_size)
20 | self.total = 0.0
21 | self.count = 0
22 | self.fmt = fmt
23 |
24 | def update(self, value, n=1):
25 | self.deque.append(value)
26 | self.count += n
27 | self.total += value * n
28 |
29 | def synchronize_between_processes(self):
30 | """
31 | Warning: does not synchronize the deque!
32 | """
33 | if not is_dist_avail_and_initialized():
34 | return
35 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
36 | dist.barrier()
37 | dist.all_reduce(t)
38 | t = t.tolist()
39 | self.count = int(t[0])
40 | self.total = t[1]
41 |
42 | @property
43 | def median(self):
44 | d = torch.tensor(list(self.deque))
45 | return d.median().item()
46 |
47 | @property
48 | def avg(self):
49 | d = torch.tensor(list(self.deque), dtype=torch.float32)
50 | return d.mean().item()
51 |
52 | @property
53 | def global_avg(self):
54 | return self.total / self.count
55 |
56 | @property
57 | def max(self):
58 | return max(self.deque)
59 |
60 | @property
61 | def value(self):
62 | return self.deque[-1]
63 |
64 | def __str__(self):
65 | return self.fmt.format(
66 | median=self.median,
67 | avg=self.avg,
68 | global_avg=self.global_avg,
69 | max=self.max,
70 | value=self.value)
71 |
72 |
73 | class MetricLogger(object):
74 | def __init__(self, delimiter="\t"):
75 | self.meters = defaultdict(SmoothedValue)
76 | self.delimiter = delimiter
77 |
78 | def update(self, **kwargs):
79 | for k, v in kwargs.items():
80 | if isinstance(v, torch.Tensor):
81 | v = v.item()
82 | assert isinstance(v, (float, int))
83 | self.meters[k].update(v)
84 |
85 | def __getattr__(self, attr):
86 | if attr in self.meters:
87 | return self.meters[attr]
88 | if attr in self.__dict__:
89 | return self.__dict__[attr]
90 | raise AttributeError("'{}' object has no attribute '{}'".format(
91 | type(self).__name__, attr))
92 |
93 | def __str__(self):
94 | loss_str = []
95 | for name, meter in self.meters.items():
96 | loss_str.append(
97 | "{}: {}".format(name, str(meter))
98 | )
99 | return self.delimiter.join(loss_str)
100 |
101 | def synchronize_between_processes(self):
102 | for meter in self.meters.values():
103 | meter.synchronize_between_processes()
104 |
105 | def add_meter(self, name, meter):
106 | self.meters[name] = meter
107 |
108 | def log_every(self, iterable, print_freq, header=None):
109 | i = 0
110 | if not header:
111 | header = ''
112 | start_time = time.time()
113 | end = time.time()
114 | iter_time = SmoothedValue(fmt='{avg:.4f}')
115 | data_time = SmoothedValue(fmt='{avg:.4f}')
116 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
117 | log_msg = [
118 | header,
119 | '[{0' + space_fmt + '}/{1}]',
120 | 'eta: {eta}',
121 | '{meters}',
122 | 'time: {time}',
123 | 'data: {data}'
124 | ]
125 | if torch.cuda.is_available():
126 | log_msg.append('max mem: {memory:.0f}')
127 | log_msg = self.delimiter.join(log_msg)
128 | MB = 1024.0 * 1024.0
129 | for obj in iterable:
130 | data_time.update(time.time() - end)
131 | yield obj
132 | iter_time.update(time.time() - end)
133 | if i % print_freq == 0 or i == len(iterable) - 1:
134 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
135 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
136 | if torch.cuda.is_available():
137 | print(log_msg.format(
138 | i, len(iterable), eta=eta_string,
139 | meters=str(self),
140 | time=str(iter_time), data=str(data_time),
141 | memory=torch.cuda.max_memory_allocated() / MB))
142 | else:
143 | print(log_msg.format(
144 | i, len(iterable), eta=eta_string,
145 | meters=str(self),
146 | time=str(iter_time), data=str(data_time)))
147 | i += 1
148 | end = time.time()
149 | total_time = time.time() - start_time
150 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
151 | print('{} Total time: {} ({:.4f} s / it)'.format(
152 | header, total_time_str, total_time / len(iterable)))
153 |
154 |
155 | def _load_checkpoint_for_ema(model_ema, checkpoint):
156 | """
157 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
158 | """
159 | mem_file = io.BytesIO()
160 | torch.save(checkpoint, mem_file)
161 | mem_file.seek(0)
162 | model_ema._load_checkpoint(mem_file)
163 |
164 |
165 | def setup_for_distributed(is_master):
166 | """
167 | This function disables printing when not in master process
168 | """
169 | import builtins as __builtin__
170 | builtin_print = __builtin__.print
171 |
172 | def print(*args, **kwargs):
173 | force = kwargs.pop('force', False)
174 | if is_master or force:
175 | builtin_print(*args, **kwargs)
176 |
177 | __builtin__.print = print
178 |
179 |
180 | def is_dist_avail_and_initialized():
181 | if not dist.is_available():
182 | return False
183 | if not dist.is_initialized():
184 | return False
185 | return True
186 |
187 |
188 | def get_world_size():
189 | if not is_dist_avail_and_initialized():
190 | return 1
191 | return dist.get_world_size()
192 |
193 |
194 | def get_rank():
195 | if not is_dist_avail_and_initialized():
196 | return 0
197 | return dist.get_rank()
198 |
199 |
200 | def is_main_process():
201 | return get_rank() == 0
202 |
203 |
204 | def save_on_master(*args, **kwargs):
205 | if is_main_process():
206 | torch.save(*args, **kwargs)
207 |
208 |
209 | def init_distributed_mode(args):
210 | if 'OMPI_COMM_WORLD_RANK' in os.environ:
211 | args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
212 | args.world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE'))
213 | args.gpu = args.rank % torch.cuda.device_count()
214 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
215 | args.rank = int(os.environ["RANK"])
216 | args.world_size = int(os.environ['WORLD_SIZE'])
217 | args.gpu = int(os.environ['LOCAL_RANK'])
218 | elif 'SLURM_PROCID' in os.environ:
219 | args.rank = int(os.environ['SLURM_PROCID'])
220 | args.gpu = args.rank % torch.cuda.device_count()
221 | else:
222 | print('Not using distributed mode')
223 | args.distributed = False
224 | return
225 |
226 | args.distributed = True
227 |
228 | torch.cuda.set_device(args.gpu)
229 | args.dist_backend = 'nccl'
230 | print('| distributed init (rank {}): {}'.format(
231 | args.rank, args.dist_url), flush=True)
232 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
233 | world_size=args.world_size, rank=args.rank)
234 | torch.distributed.barrier()
235 | setup_for_distributed(args.rank == 0)
236 |
--------------------------------------------------------------------------------
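
A small sketch of the logging utilities above in a single-process run:

import time
from lib.utils import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter='  ')
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

# log_every wraps any iterable and prints progress, ETA and the meters
# every print_freq iterations.
for step in logger.log_every(range(10), print_freq=5, header='Epoch: [0]'):
    logger.update(loss=1.0 / (step + 1), lr=1e-3)
    time.sleep(0.01)
print('Averaged stats:', logger)
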
/model/module/Linear_super.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class LinearSuper(nn.Linear):
7 | def __init__(self, super_in_dim, super_out_dim, bias=True, uniform_=None, non_linear='linear', scale=False):
8 | super().__init__(super_in_dim, super_out_dim, bias=bias)
9 |
10 | # super_in_dim and super_out_dim indicate the largest network!
11 | self.super_in_dim = super_in_dim
12 | self.super_out_dim = super_out_dim
13 |
14 | # input_dim and output_dim indicate the current sampled size
15 | self.sample_in_dim = None
16 | self.sample_out_dim = None
17 |
18 | self.samples = {}
19 |
20 | self.scale = scale
21 | self._reset_parameters(bias, uniform_, non_linear)
22 | self.profiling = False
23 |
24 | def profile(self, mode=True):
25 | self.profiling = mode
26 |
27 | def sample_parameters(self, resample=False):
28 | if self.profiling or resample:
29 | return self._sample_parameters()
30 | return self.samples
31 |
32 | def _reset_parameters(self, bias, uniform_, non_linear):
33 | nn.init.xavier_uniform_(self.weight) if uniform_ is None else uniform_(
34 | self.weight, non_linear=non_linear)
35 | if bias:
36 | nn.init.constant_(self.bias, 0.)
37 |
38 | def set_sample_config(self, sample_in_dim, sample_out_dim):
39 | self.sample_in_dim = sample_in_dim
40 | self.sample_out_dim = sample_out_dim
41 |
42 | self._sample_parameters()
43 |
44 | def _sample_parameters(self):
45 | self.samples['weight'] = sample_weight(self.weight, self.sample_in_dim, self.sample_out_dim)
46 | self.samples['bias'] = self.bias
47 | self.sample_scale = self.super_out_dim/self.sample_out_dim
48 | if self.bias is not None:
49 | self.samples['bias'] = sample_bias(self.bias, self.sample_out_dim)
50 | return self.samples
51 |
52 | def forward(self, x):
53 | self.sample_parameters()
54 | return F.linear(x, self.samples['weight'], self.samples['bias']) * (self.sample_scale if self.scale else 1)
55 |
56 | def calc_sampled_param_num(self):
57 | assert 'weight' in self.samples.keys()
58 | weight_numel = self.samples['weight'].numel()
59 |
60 | if self.samples['bias'] is not None:
61 | bias_numel = self.samples['bias'].numel()
62 | else:
63 | bias_numel = 0
64 |
65 | return weight_numel + bias_numel
66 | def get_complexity(self, sequence_length):
67 | total_flops = 0
68 | total_flops += sequence_length * np.prod(self.samples['weight'].size())
69 | return total_flops
70 |
71 | def sample_weight(weight, sample_in_dim, sample_out_dim):
72 | sample_weight = weight[:, :sample_in_dim]
73 | sample_weight = sample_weight[:sample_out_dim, :]
74 |
75 | return sample_weight
76 |
77 |
78 | def sample_bias(bias, sample_out_dim):
79 | sample_bias = bias[:sample_out_dim]
80 |
81 | return sample_bias
82 |
--------------------------------------------------------------------------------
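The weight sharing above is plain slicing: the sampled sub-layer is the top-left `sample_out_dim x sample_in_dim` block of the super weight. A quick usage sketch (dimensions are illustrative assumptions):

import torch
from model.module.Linear_super import LinearSuper

layer = LinearSuper(super_in_dim=4, super_out_dim=4)   # largest network: 4 -> 4
layer.set_sample_config(sample_in_dim=2, sample_out_dim=3)
print(layer.samples['weight'].shape)   # torch.Size([3, 2]) -- a view, no copy
print(layer.calc_sampled_param_num())  # 3*2 weights + 3 biases = 9
print(layer(torch.randn(5, 2)).shape)  # torch.Size([5, 3])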
/model/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tinyvision/PreNAS/0050c7a22482e8736f148bc41ab0d952968a8748/model/module/__init__.py
--------------------------------------------------------------------------------
/model/module/embedding_super.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from model.utils import to_2tuple
5 | import numpy as np
6 |
7 | class PatchembedSuper(nn.Module):
8 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, scale=False):
9 | super(PatchembedSuper, self).__init__()
10 |
11 | img_size = to_2tuple(img_size)
12 | patch_size = to_2tuple(patch_size)
13 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
14 | self.img_size = img_size
15 | self.patch_size = patch_size
16 | self.num_patches = num_patches
17 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
18 | self.super_embed_dim = embed_dim
19 | self.scale = scale
20 |
21 | # sampled_
22 | self.sample_embed_dim = None
23 | self.sampled_weight = None
24 | self.sampled_bias = None
25 | self.sampled_scale = None
26 |
27 | def set_sample_config(self, sample_embed_dim):
28 | self.sample_embed_dim = sample_embed_dim
29 | self.sampled_weight = self.proj.weight[:sample_embed_dim, ...]
30 | self.sampled_bias = self.proj.bias[:self.sample_embed_dim, ...]
31 | if self.scale:
32 | self.sampled_scale = self.super_embed_dim / sample_embed_dim
33 | def forward(self, x):
34 | B, C, H, W = x.shape
35 | assert H == self.img_size[0] and W == self.img_size[1], \
36 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
37 | x = F.conv2d(x, self.sampled_weight, self.sampled_bias, stride=self.patch_size, padding=self.proj.padding, dilation=self.proj.dilation).flatten(2).transpose(1,2)
38 | if self.scale:
39 | return x * self.sampled_scale
40 | return x
41 | def calc_sampled_param_num(self):
42 | return self.sampled_weight.numel() + self.sampled_bias.numel()
43 |
44 | def get_complexity(self, sequence_length):
45 | total_flops = 0
46 | if self.sampled_bias is not None:
47 | total_flops += self.sampled_bias.size(0)
48 | total_flops += sequence_length * np.prod(self.sampled_weight.size())
49 | return total_flops
--------------------------------------------------------------------------------
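`PatchembedSuper` samples only the projection's output channels; the spatial patching itself is fixed. An illustrative sketch (sizes are assumptions):

import torch
from model.module.embedding_super import PatchembedSuper

patch_embed = PatchembedSuper(img_size=224, patch_size=16, embed_dim=240)
patch_embed.set_sample_config(sample_embed_dim=192)  # keep the first 192 filters
tokens = patch_embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 196, 192]) -- 14x14 patches at the sampled dim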
/model/module/layernorm_super.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SwitchableLayerNormSuper(nn.Module):
7 | def __init__(self, embed_dim_list):
8 | super(SwitchableLayerNormSuper, self).__init__()
9 |
10 | self.embed_dim_list = embed_dim_list
11 |
12 | # the largest embed dim
13 | self.super_embed_dim = max(embed_dim_list)
14 |
15 | # the current sampled embed dim
16 | self.sample_embed_dim = None
17 |
18 | self.lns = nn.ModuleList([nn.LayerNorm(dim) for dim in embed_dim_list])
19 |
20 | def set_sample_config(self, sample_embed_dim):
21 | self.sample_embed_dim = sample_embed_dim
22 | self.sample_idx = self.embed_dim_list.index(sample_embed_dim)
23 |
24 | def forward(self, x):
25 | return self.lns[self.sample_idx](x)
26 |
27 | def calc_sampled_param_num(self):
28 | ln = self.lns[self.sample_idx]
29 | return ln.weight.numel() + ln.bias.numel()
30 |
31 | def get_complexity(self, sequence_length):
32 | return sequence_length * self.sample_embed_dim
33 |
34 |
35 | class LayerNormSuper(torch.nn.LayerNorm):
36 | def __init__(self, super_embed_dim):
37 | super().__init__(super_embed_dim)
38 |
39 | # the largest embed dim
40 | self.super_embed_dim = super_embed_dim
41 |
42 | # the current sampled embed dim
43 | self.sample_embed_dim = None
44 |
45 | self.samples = {}
46 | self.profiling = False
47 |
48 | def profile(self, mode=True):
49 | self.profiling = mode
50 |
51 | def sample_parameters(self, resample=False):
52 | if self.profiling or resample:
53 | return self._sample_parameters()
54 | return self.samples
55 |
56 | def _sample_parameters(self):
57 | self.samples['weight'] = self.weight[:self.sample_embed_dim]
58 | self.samples['bias'] = self.bias[:self.sample_embed_dim]
59 | return self.samples
60 |
61 | def set_sample_config(self, sample_embed_dim):
62 | self.sample_embed_dim = sample_embed_dim
63 | self._sample_parameters()
64 |
65 | def forward(self, x):
66 | self.sample_parameters()
67 | return F.layer_norm(x, (self.sample_embed_dim,), weight=self.samples['weight'], bias=self.samples['bias'], eps=self.eps)
68 |
69 | def calc_sampled_param_num(self):
70 | assert 'weight' in self.samples.keys()
71 | assert 'bias' in self.samples.keys()
72 | return self.samples['weight'].numel() + self.samples['bias'].numel()
73 |
74 | def get_complexity(self, sequence_length):
75 | return sequence_length * self.sample_embed_dim
76 |
--------------------------------------------------------------------------------
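The two variants trade parameters for exactness: `LayerNormSuper` slices one shared `(super_embed_dim,)` affine pair, while `SwitchableLayerNormSuper` keeps an independent `nn.LayerNorm` per candidate width. A quick sketch (widths are illustrative assumptions):

import torch
from model.module.layernorm_super import LayerNormSuper, SwitchableLayerNormSuper

x = torch.randn(2, 197, 192)

shared = LayerNormSuper(super_embed_dim=240)  # one (240,) weight/bias, sliced
shared.set_sample_config(192)
print(shared(x).shape)                        # torch.Size([2, 197, 192])

switched = SwitchableLayerNormSuper([192, 216, 240])  # one LayerNorm per width
switched.set_sample_config(192)                       # selects self.lns[0]
print(switched.calc_sampled_param_num())              # 192 + 192 = 384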
/model/module/multihead_super.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 | import torch.nn.functional as F
5 | from .Linear_super import LinearSuper
6 | from .qkv_super import qkv_super
7 | from .scaling_super import ScalingSuper
8 | from ..utils import trunc_normal_
9 | from torch.cuda.amp import autocast
10 |
11 | def softmax(x, dim, onnx_trace=False):
12 | if onnx_trace:
13 | return F.softmax(x.float(), dim=dim)
14 | else:
15 | return F.softmax(x, dim=dim, dtype=torch.float32)
16 |
17 | class RelativePosition2D_super(nn.Module):
18 |
19 | def __init__(self, num_units, max_relative_position):
20 | super().__init__()
21 |
22 | self.num_units = num_units
23 | self.max_relative_position = max_relative_position
24 | # The first element in embeddings_table_v is the vertical embedding for the class
25 | self.embeddings_table_v = nn.Parameter(torch.randn(max_relative_position * 2 + 2, num_units))
26 | self.embeddings_table_h = nn.Parameter(torch.randn(max_relative_position * 2 + 2, num_units))
27 |
28 | trunc_normal_(self.embeddings_table_v, std=.02)
29 | trunc_normal_(self.embeddings_table_h, std=.02)
30 |
31 | self.sample_head_dim = None
32 | self.sample_embeddings_table_h = None
33 | self.sample_embeddings_table_v = None
34 |
35 | def set_sample_config(self, sample_head_dim):
36 | self.sample_head_dim = sample_head_dim
37 | self.sample_embeddings_table_h = self.embeddings_table_h[:,:sample_head_dim]
38 | self.sample_embeddings_table_v = self.embeddings_table_v[:,:sample_head_dim]
39 |
40 | def calc_sampled_param_num(self):
41 | return self.sample_embeddings_table_h.numel() + self.sample_embeddings_table_v.numel()
42 |
43 | def forward(self, length_q, length_k):
44 | # remove the first cls token distance computation
45 | length_q = length_q - 1
46 | length_k = length_k - 1
47 | device = self.embeddings_table_v.device
48 | range_vec_q = torch.arange(length_q, device=device)
49 | range_vec_k = torch.arange(length_k, device=device)
50 | # compute the row and column distance
51 | distance_mat_v = (range_vec_k[None, :] // int(length_q ** 0.5 ) - range_vec_q[:, None] // int(length_q ** 0.5 ))
52 | distance_mat_h = (range_vec_k[None, :] % int(length_q ** 0.5 ) - range_vec_q[:, None] % int(length_q ** 0.5 ))
53 | # clip the distance to the range of [-max_relative_position, max_relative_position]
54 | distance_mat_clipped_v = torch.clamp(distance_mat_v, -self.max_relative_position, self.max_relative_position)
55 | distance_mat_clipped_h = torch.clamp(distance_mat_h, -self.max_relative_position, self.max_relative_position)
56 |
57 | # shift the distances into the range [1, 2 * max_relative_position + 1]; index 0 is reserved for the cls token
58 | final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1
59 | final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1
60 | # pad with index 0, which represents the cls token
61 | final_mat_v = torch.nn.functional.pad(final_mat_v, (1,0,1,0), "constant", 0)
62 | final_mat_h = torch.nn.functional.pad(final_mat_h, (1,0,1,0), "constant", 0)
63 |
64 | final_mat_v = final_mat_v.long()
65 | final_mat_h = final_mat_h.long()
66 | # get the embeddings with the corresponding distance
67 | embeddings = self.sample_embeddings_table_v[final_mat_v] + self.sample_embeddings_table_h[final_mat_h]
68 |
69 | return embeddings
70 |
71 | class AttentionSuper(nn.Module):
72 | def __init__(self, super_embed_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., normalization = False, relative_position = False,
73 | num_patches = None, max_relative_position=14, scale=False, change_qkv = False, choices=None, scale_attn=False):
74 | super().__init__()
75 | self.num_heads = num_heads
76 | head_dim = super_embed_dim // num_heads
77 | self.scale = qk_scale or head_dim ** -0.5
78 | self.super_embed_dim = super_embed_dim
79 |
80 | self.fc_scale = scale
81 | self.change_qkv = change_qkv
82 | self.scale_attn = scale_attn
83 |
84 | self.choices = choices
85 |
86 | if change_qkv:
87 | self.qkv = qkv_super(super_embed_dim, 3 * super_embed_dim, bias=qkv_bias)
88 | if scale_attn:
89 | self.qk_scaling = ScalingSuper([n_head * 64 for n_head in self.choices['num_heads']])
90 | self.v_scaling = ScalingSuper([n_head * 64 for n_head in self.choices['num_heads']])
91 | else:
92 | self.qkv = LinearSuper(super_embed_dim, 3 * super_embed_dim, bias=qkv_bias)
93 |
94 | self.relative_position = relative_position
95 | if self.relative_position:
96 | self.rel_pos_embed_k = RelativePosition2D_super(super_embed_dim //num_heads, max_relative_position)
97 | self.rel_pos_embed_v = RelativePosition2D_super(super_embed_dim //num_heads, max_relative_position)
98 | self.max_relative_position = max_relative_position
99 | self.sample_qk_embed_dim = None
100 | self.sample_v_embed_dim = None
101 | self.sample_num_heads = None
102 | self.sample_scale = None
103 | self.sample_in_embed_dim = None
104 |
105 | self.proj = LinearSuper(super_embed_dim, super_embed_dim)
106 |
107 | self.attn_drop = nn.Dropout(attn_drop)
108 | self.proj_drop = nn.Dropout(proj_drop)
109 |
110 | def set_sample_config(self, sample_q_embed_dim=None, sample_num_heads=None, sample_in_embed_dim=None):
111 |
112 | self.sample_in_embed_dim = sample_in_embed_dim
113 | self.sample_num_heads = sample_num_heads
114 | if not self.change_qkv:
115 | self.sample_qk_embed_dim = self.super_embed_dim
116 | self.sample_scale = (sample_in_embed_dim // self.sample_num_heads) ** -0.5
117 | else:
118 | if self.scale_attn:
119 | self.qk_scaling.set_sample_config(sample_q_embed_dim)
120 | self.v_scaling.set_sample_config(sample_q_embed_dim)
121 | self.sample_qk_embed_dim = sample_q_embed_dim
122 | self.sample_scale = (self.sample_qk_embed_dim // self.sample_num_heads) ** -0.5
123 |
124 | self.qkv.set_sample_config(sample_in_dim=sample_in_embed_dim, sample_out_dim=3*self.sample_qk_embed_dim)
125 | self.proj.set_sample_config(sample_in_dim=self.sample_qk_embed_dim, sample_out_dim=sample_in_embed_dim)
126 | if self.relative_position:
127 | self.rel_pos_embed_k.set_sample_config(self.sample_qk_embed_dim // sample_num_heads)
128 | self.rel_pos_embed_v.set_sample_config(self.sample_qk_embed_dim // sample_num_heads)
129 | def calc_sampled_param_num(self):
130 | # child modules (qkv, proj, relative position tables) are counted on their own
131 | return 0
132 | def get_complexity(self, sequence_length):
133 | total_flops = 0
134 | total_flops += self.qkv.get_complexity(sequence_length)
135 | # attn
136 | total_flops += sequence_length * sequence_length * self.sample_qk_embed_dim
137 | # x
138 | total_flops += sequence_length * sequence_length * self.sample_qk_embed_dim
139 | total_flops += self.proj.get_complexity(sequence_length)
140 | if self.relative_position:
141 | total_flops += self.max_relative_position * sequence_length * sequence_length + sequence_length * sequence_length / 2.0
142 | total_flops += self.max_relative_position * sequence_length * sequence_length + sequence_length * self.sample_qk_embed_dim / 2.0
143 | return total_flops
144 |
145 | def forward(self, x):
146 | B, N, C = x.shape
147 | qkv = self.qkv(x).reshape(B, N, 3, self.sample_num_heads, -1).permute(2, 0, 3, 1, 4)
148 | with autocast(enabled=False):
149 | q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float() # make torchscript happy (cannot use tensor as tuple)
150 |
151 | if self.scale_attn:
152 | q = self.qk_scaling(q.transpose(1, 2).reshape(B, N, -1)).reshape(B, N, self.sample_num_heads, -1).transpose(1, 2)
153 | attn = (q @ k.transpose(-2, -1)) * self.sample_scale
154 | if self.relative_position:
155 | r_p_k = self.rel_pos_embed_k(N, N)
156 | attn = attn + (q.permute(2, 0, 1, 3).reshape(N, self.sample_num_heads * B, -1) @ r_p_k.transpose(2, 1)) \
157 | .transpose(1, 0).reshape(B, self.sample_num_heads, N, N) * self.sample_scale
158 |
159 | attn = attn.softmax(dim=-1)
160 | attn = self.attn_drop(attn)
161 |
162 | if self.scale_attn:
163 | v = self.v_scaling(v.transpose(1, 2).reshape(B, N, -1)).reshape(B, N, self.sample_num_heads, -1).transpose(1, 2)
164 | x = (attn @ v).transpose(1,2).reshape(B, N, -1)
165 | if self.relative_position:
166 | r_p_v = self.rel_pos_embed_v(N, N)
167 | attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * self.sample_num_heads, -1)
168 | # The attention map has shape (B, num_heads, N, N); reshape it to (N, B*num_heads, N) and
169 | # batch-matmul with the relative position embedding of V (N, N, head_dim) to get
170 | # (N, B*num_heads, head_dim), then reshape back to the same shape as x: (B, N, num_heads * head_dim)
171 | x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B, self.sample_num_heads, N, -1).transpose(2,1).reshape(B, N, -1)
172 |
173 | if self.fc_scale:
174 | x = x * (self.super_embed_dim / self.sample_qk_embed_dim)
175 | x = self.proj(x)
176 | x = self.proj_drop(x)
177 | return x
178 |
--------------------------------------------------------------------------------
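Note that with `change_qkv=True` the per-head dimension stays fixed at 64 (the encoder layer later in this repo passes `sample_num_heads * 64` as `sample_q_embed_dim`), so sampling fewer heads narrows the qkv projection rather than shrinking each head. An illustrative sketch under that convention:

import torch
from model.module.multihead_super import AttentionSuper

attn = AttentionSuper(super_embed_dim=240, num_heads=4, change_qkv=True,
                      choices={'num_heads': [3, 4]})
attn.set_sample_config(sample_q_embed_dim=3 * 64,  # head dim fixed at 64
                       sample_num_heads=3,
                       sample_in_embed_dim=192)
print(attn(torch.randn(2, 197, 192)).shape)  # torch.Size([2, 197, 192])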
/model/module/qkv_super.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | class qkv_super(nn.Linear):
8 | def __init__(self, super_in_dim, super_out_dim, bias=True, uniform_=None, non_linear='linear', scale=False):
9 | super().__init__(super_in_dim, super_out_dim, bias=bias)
10 |
11 | # super_in_dim and super_out_dim indicate the largest network!
12 | self.super_in_dim = super_in_dim
13 | self.super_out_dim = super_out_dim
14 |
15 | # input_dim and output_dim indicate the current sampled size
16 | self.sample_in_dim = None
17 | self.sample_out_dim = None
18 |
19 | self.samples = {}
20 |
21 | self.scale = scale
22 | # self._reset_parameters(bias, uniform_, non_linear)
23 | self.profiling = False
24 |
25 | def profile(self, mode=True):
26 | self.profiling = mode
27 |
28 | def sample_parameters(self, resample=False):
29 | if self.profiling or resample:
30 | return self._sample_parameters()
31 | return self.samples
32 |
33 | def _reset_parameters(self, bias, uniform_, non_linear):
34 | nn.init.xavier_uniform_(self.weight) if uniform_ is None else uniform_(
35 | self.weight, non_linear=non_linear)
36 | if bias:
37 | nn.init.constant_(self.bias, 0.)
38 |
39 | def set_sample_config(self, sample_in_dim, sample_out_dim):
40 | self.sample_in_dim = sample_in_dim
41 | self.sample_out_dim = sample_out_dim
42 |
43 | self._sample_parameters()
44 |
45 | def _sample_parameters(self):
46 | self.samples['weight'] = sample_weight(self.weight, self.sample_in_dim, self.sample_out_dim)
47 | self.samples['bias'] = self.bias
48 | self.sample_scale = self.super_out_dim / self.sample_out_dim
49 | if self.bias is not None:
50 | self.samples['bias'] = sample_bias(self.bias, self.sample_out_dim)
51 | return self.samples
52 |
53 | def forward(self, x):
54 | self.sample_parameters()
55 | return F.linear(x, self.samples['weight'], self.samples['bias']) * (self.sample_scale if self.scale else 1)
56 |
57 | def calc_sampled_param_num(self):
58 | assert 'weight' in self.samples.keys()
59 | weight_numel = self.samples['weight'].numel()
60 |
61 | if self.samples['bias'] is not None:
62 | bias_numel = self.samples['bias'].numel()
63 | else:
64 | bias_numel = 0
65 |
66 | return weight_numel + bias_numel
67 | def get_complexity(self, sequence_length):
68 | total_flops = 0
69 | total_flops += sequence_length * np.prod(self.samples['weight'].size())
70 | return total_flops
71 |
72 | def sample_weight(weight, sample_in_dim, sample_out_dim):
73 |
74 | sample_weight = weight[:, :sample_in_dim]
75 | sample_weight = torch.cat([sample_weight[i:sample_out_dim:3, :] for i in range(3)], dim=0)  # regroup interleaved (q0, k0, v0, q1, ...) rows into contiguous q, k, v blocks
76 |
77 | return sample_weight
78 |
79 |
80 | def sample_bias(bias, sample_out_dim):
81 | # bias entries are interleaved across q/k/v as well, so regroup them the same way
82 | sample_bias = torch.cat([bias[i:sample_out_dim:3] for i in range(3)])
83 |
84 | return sample_bias
85 |
--------------------------------------------------------------------------------
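The only departure from `LinearSuper` is in `sample_weight`/`sample_bias`: the super qkv matrix is laid out with q, k, v rows interleaved, and the strided slice regroups every third row into contiguous q, k, v blocks. A tiny numeric check (illustrative):

import torch
from model.module.qkv_super import sample_weight

w = torch.arange(12).reshape(6, 2)  # rows laid out as q0, k0, v0, q1, k1, v1
sub = sample_weight(w, sample_in_dim=2, sample_out_dim=6)
print(sub[:, 0].tolist())  # [0, 6, 2, 8, 4, 10] -> q rows, then k rows, then v rows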
/model/module/scaling_super.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ScalingSuper(nn.Module):
7 | def __init__(self, embed_dim_list):
8 | super(ScalingSuper, self).__init__()
9 |
10 | self.embed_dim_list = embed_dim_list
11 |
12 | # the largest embed dim
13 | self.super_embed_dim = max(embed_dim_list)
14 |
15 | # the current sampled embed dim
16 | self.sample_embed_dim = None
17 |
18 | self.scalings = nn.Parameter(1e-4 * torch.ones(len(embed_dim_list), self.super_embed_dim))
19 |
20 | def set_sample_config(self, sample_embed_dim):
21 | self.sample_embed_dim = sample_embed_dim
22 | self.sample_idx = self.embed_dim_list.index(sample_embed_dim)
23 |
24 | def forward(self, x):
25 | return x * self.scalings[self.sample_idx][:self.sample_embed_dim]
26 |
27 | def calc_sampled_param_num(self):
28 | return 0 #self.scalings[self.sample_idx][:self.sample_embed_dim].numel()
29 |
30 | def get_complexity(self, sequence_length):
31 | return 0 #sequence_length * self.sample_embed_dim
32 |
33 |
--------------------------------------------------------------------------------
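`ScalingSuper` holds one learnable scaling row per candidate width, initialized at 1e-4 (in the spirit of LayerScale), and applies the selected row elementwise. A quick sketch (widths are illustrative assumptions):

import torch
from model.module.scaling_super import ScalingSuper

scaling = ScalingSuper([192, 216, 240])
scaling.set_sample_config(216)         # row 1, truncated to its first 216 entries
y = scaling(torch.randn(2, 197, 216))  # elementwise scale, all entries init 1e-4
print(y.shape)                         # torch.Size([2, 197, 216])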
/model/supernet_transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from model.module.Linear_super import LinearSuper
7 | from model.module.layernorm_super import LayerNormSuper, SwitchableLayerNormSuper
8 | from model.module.scaling_super import ScalingSuper
9 | from model.module.multihead_super import AttentionSuper
10 | from model.module.embedding_super import PatchembedSuper
11 | from model.utils import trunc_normal_
12 | from model.utils import DropPath
13 | import numpy as np
14 |
15 | def gelu(x: torch.Tensor) -> torch.Tensor:
16 | if hasattr(torch.nn.functional, 'gelu'):
17 | return torch.nn.functional.gelu(x.float()).type_as(x)
18 | else:
19 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
20 |
21 |
22 | class Vision_TransformerSuper(nn.Module):
23 |
24 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
25 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
26 | drop_path_rate=0., pre_norm=True, scale=False, gp=False, relative_position=False, change_qkv=False, abs_pos = True, max_relative_position=14,
27 | choices=None, switch_ln=False, scale_attn=False, scale_mlp=False, scale_embed=False):
28 | super(Vision_TransformerSuper, self).__init__()
29 | # the configs of super arch
30 | self.super_embed_dim = embed_dim
31 | # self.super_embed_dim = args.embed_dim
32 | self.super_mlp_ratio = mlp_ratio
33 | self.super_layer_num = depth
34 | self.super_num_heads = num_heads
35 | self.super_dropout = drop_rate
36 | self.super_attn_dropout = attn_drop_rate
37 | self.num_classes = num_classes
38 | self.pre_norm=pre_norm
39 | self.scale=scale
40 | self.patch_embed_super = PatchembedSuper(img_size=img_size, patch_size=patch_size,
41 | in_chans=in_chans, embed_dim=embed_dim)
42 | self.gp = gp
43 | self.choices = choices
44 | self.scale_embed = scale_embed
45 |
46 | # configs for the sampled subTransformer
47 | self.sample_embed_dim = None
48 | self.sample_mlp_ratio = None
49 | self.sample_layer_num = None
50 | self.sample_num_heads = None
51 | self.sample_dropout = None
52 | self.sample_output_dim = None
53 |
54 | self.blocks = nn.ModuleList()
55 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
56 |
57 | for i in range(depth):
58 | self.blocks.append(TransformerEncoderLayer(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
59 | qkv_bias=qkv_bias, qk_scale=qk_scale, dropout=drop_rate,
60 | attn_drop=attn_drop_rate, drop_path=dpr[i],
61 | pre_norm=pre_norm, scale=self.scale,
62 | change_qkv=change_qkv, relative_position=relative_position,
63 | max_relative_position=max_relative_position,
64 | choices=choices, switch_ln=switch_ln,
65 | scale_attn=scale_attn, scale_mlp=scale_mlp, scale_embed=scale_embed))
66 |
67 | # parameters for vision transformer
68 | num_patches = self.patch_embed_super.num_patches
69 |
70 | self.abs_pos = abs_pos
71 | if self.abs_pos:
72 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
73 | trunc_normal_(self.pos_embed, std=.02)
74 |
75 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
76 | trunc_normal_(self.cls_token, std=.02)
77 |
78 | if scale_embed:
79 | self.embed_scaling = ScalingSuper(self.choices['embed_dim'])
80 |
81 | # self.pos_drop = nn.Dropout(p=drop_rate)
82 | if self.pre_norm:
83 | if switch_ln:
84 | self.norm = SwitchableLayerNormSuper(self.choices['embed_dim'])
85 | else:
86 | self.norm = LayerNormSuper(super_embed_dim=embed_dim)
87 |
88 | # classifier head
89 | self.head = LinearSuper(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
90 |
91 | self.apply(self._init_weights)
92 |
93 | def _init_weights(self, m):
94 | if isinstance(m, nn.Linear):
95 | trunc_normal_(m.weight, std=.02)
96 | if isinstance(m, nn.Linear) and m.bias is not None:
97 | nn.init.constant_(m.bias, 0)
98 | elif isinstance(m, nn.LayerNorm):
99 | nn.init.constant_(m.bias, 0)
100 | nn.init.constant_(m.weight, 1.0)
101 |
102 | @torch.jit.ignore
103 | def no_weight_decay(self):
104 | return {'pos_embed', 'cls_token', 'rel_pos_embed'}
105 |
106 | def get_classifier(self):
107 | return self.head
108 |
109 | #def reset_classifier(self, num_classes, global_pool=''):
110 | # self.num_classes = num_classes
111 | # self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
112 |
113 | def set_sample_config(self, config: dict):
114 | self.sample_embed_dim = config['embed_dim']
115 | self.sample_mlp_ratio = config['mlp_ratio']
116 | self.sample_layer_num = config['layer_num']
117 | self.sample_num_heads = config['num_heads']
118 | self.sample_dropout = calc_dropout(self.super_dropout, self.sample_embed_dim[0], self.super_embed_dim)
119 | self.patch_embed_super.set_sample_config(self.sample_embed_dim[0])
120 | if self.scale_embed:
121 | self.embed_scaling.set_sample_config(self.sample_embed_dim[0])
122 | self.sample_output_dim = [out_dim for out_dim in self.sample_embed_dim[1:]] + [self.sample_embed_dim[-1]]
123 | for i, blocks in enumerate(self.blocks):
124 | # not exceed sample layer number
125 | if i < self.sample_layer_num:
126 | sample_dropout = calc_dropout(self.super_dropout, self.sample_embed_dim[i], self.super_embed_dim)
127 | sample_attn_dropout = calc_dropout(self.super_attn_dropout, self.sample_embed_dim[i], self.super_embed_dim)
128 | blocks.set_sample_config(is_identity_layer=False,
129 | sample_embed_dim=self.sample_embed_dim[i],
130 | sample_mlp_ratio=self.sample_mlp_ratio[i],
131 | sample_num_heads=self.sample_num_heads[i],
132 | sample_dropout=sample_dropout,
133 | sample_out_dim=self.sample_output_dim[i],
134 | sample_attn_dropout=sample_attn_dropout)
135 | # exceeds sample layer number
136 | else:
137 | blocks.set_sample_config(is_identity_layer=True)
138 | if self.pre_norm:
139 | self.norm.set_sample_config(self.sample_embed_dim[-1])
140 | self.head.set_sample_config(self.sample_embed_dim[-1], self.num_classes)
141 |
142 | def get_sampled_params_numel(self, config):
143 | self.set_sample_config(config)
144 | numels = []
145 | for name, module in self.named_modules():
146 | if hasattr(module, 'calc_sampled_param_num'):
147 | if name.split('.')[0] == 'blocks' and int(name.split('.')[1]) >= config['layer_num']:
148 | continue
149 | numels.append(module.calc_sampled_param_num())
150 |
151 | return sum(numels) + self.sample_embed_dim[0] * (2 + self.patch_embed_super.num_patches)  # plus cls_token (1) and pos_embed (num_patches + 1) parameters
152 | def get_complexity(self, sequence_length):
153 | total_flops = 0
154 | total_flops += self.patch_embed_super.get_complexity(sequence_length)
155 | total_flops += np.prod(self.pos_embed[..., :self.sample_embed_dim[0]].size()) / 2.0
156 | for blk in self.blocks:
157 | total_flops += blk.get_complexity(sequence_length+1)
158 | total_flops += self.head.get_complexity(sequence_length+1)
159 | return total_flops
160 | def forward_features(self, x):
161 | B = x.shape[0]
162 | x = self.patch_embed_super(x)
163 | cls_tokens = self.cls_token[..., :self.sample_embed_dim[0]].expand(B, -1, -1)
164 | x = torch.cat((cls_tokens, x), dim=1)
165 | if self.abs_pos:
166 | x = x + self.pos_embed[..., :self.sample_embed_dim[0]]
167 |
168 | if self.scale_embed:
169 | x = self.embed_scaling(x)
170 |
171 | x = F.dropout(x, p=self.sample_dropout, training=self.training)
172 |
173 | # start_time = time.time()
174 | for blk in self.blocks:
175 | x = blk(x)
176 | # print(time.time()-start_time)
177 | if self.pre_norm:
178 | x = self.norm(x)
179 |
180 | if self.gp:
181 | return torch.mean(x[:, 1:] , dim=1)
182 |
183 | return x[:, 0]
184 |
185 | def forward(self, x):
186 | x = self.forward_features(x)
187 | x = self.head(x)
188 | return x
189 |
190 |
191 | class TransformerEncoderLayer(nn.Module):
192 | """Encoder layer block.
193 |
194 | A Transformer encoder layer whose embed dim, number of heads, and MLP ratio
195 | can be re-sampled on the fly via `set_sample_config`.
196 | """
197 |
198 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, dropout=0., attn_drop=0.,
199 | drop_path=0., act_layer=nn.GELU, pre_norm=True, scale=False,
200 | relative_position=False, change_qkv=False, max_relative_position=14, choices=None, switch_ln=False,
201 | scale_attn=False, scale_mlp=False, scale_embed=False):
202 | super().__init__()
203 |
204 | # the configs of super arch of the encoder, three dimension [embed_dim, mlp_ratio, and num_heads]
205 | self.super_embed_dim = dim
206 | self.super_mlp_ratio = mlp_ratio
207 | self.super_ffn_embed_dim_this_layer = int(mlp_ratio * dim)
208 | self.super_num_heads = num_heads
209 | self.normalize_before = pre_norm
210 | self.super_dropout = attn_drop
211 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
212 | self.scale = scale
213 | self.relative_position = relative_position
214 | # self.super_activation_dropout = getattr(args, 'activation_dropout', 0)
215 | self.choices = choices
216 | self.scale_mlp = scale_mlp
217 | self.scale_embed = scale_embed
218 |
219 | # the configs of current sampled arch
220 | self.sample_embed_dim = None
221 | self.sample_mlp_ratio = None
222 | self.sample_ffn_embed_dim_this_layer = None
223 | self.sample_num_heads_this_layer = None
224 | self.sample_scale = None
225 | self.sample_dropout = None
226 | self.sample_attn_dropout = None
227 |
228 | self.is_identity_layer = None
229 | self.attn = AttentionSuper(
230 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
231 | proj_drop=dropout, scale=self.scale, relative_position=self.relative_position, change_qkv=change_qkv,
232 | max_relative_position=max_relative_position, choices=choices, scale_attn=scale_attn,
233 | )
234 |
235 | if switch_ln:
236 | self.attn_layer_norm = SwitchableLayerNormSuper(self.choices['embed_dim'])
237 | self.ffn_layer_norm = SwitchableLayerNormSuper(self.choices['embed_dim'])
238 | else:
239 | self.attn_layer_norm = LayerNormSuper(self.super_embed_dim)
240 | self.ffn_layer_norm = LayerNormSuper(self.super_embed_dim)
241 |
242 | if scale_embed:
243 | self.embed_scaling1 = ScalingSuper(self.choices['embed_dim'])
244 | self.embed_scaling2 = ScalingSuper(self.choices['embed_dim'])
245 |
246 | # self.dropout = dropout
247 | self.activation_fn = gelu
248 | # self.normalize_before = args.encoder_normalize_before
249 |
250 | self.fc1 = LinearSuper(super_in_dim=self.super_embed_dim, super_out_dim=self.super_ffn_embed_dim_this_layer)
251 | if scale_mlp:
252 | self.mlp_scaling = ScalingSuper([emb * ratio for emb in self.choices['embed_dim'] for ratio in self.choices['mlp_ratio']])
253 | self.fc2 = LinearSuper(super_in_dim=self.super_ffn_embed_dim_this_layer, super_out_dim=self.super_embed_dim)
254 |
255 |
256 | def set_sample_config(self, is_identity_layer, sample_embed_dim=None, sample_mlp_ratio=None, sample_num_heads=None, sample_dropout=None, sample_attn_dropout=None, sample_out_dim=None):
257 |
258 | if is_identity_layer:
259 | self.is_identity_layer = True
260 | return
261 |
262 | self.is_identity_layer = False
263 |
264 | self.sample_embed_dim = sample_embed_dim
265 | self.sample_out_dim = sample_out_dim
266 | self.sample_mlp_ratio = sample_mlp_ratio
267 | self.sample_ffn_embed_dim_this_layer = int(sample_embed_dim*sample_mlp_ratio)
268 | self.sample_num_heads_this_layer = sample_num_heads
269 |
270 | self.sample_dropout = sample_dropout
271 | self.sample_attn_dropout = sample_attn_dropout
272 | self.attn_layer_norm.set_sample_config(sample_embed_dim=self.sample_embed_dim)
273 |
274 | self.attn.set_sample_config(sample_q_embed_dim=self.sample_num_heads_this_layer*64, sample_num_heads=self.sample_num_heads_this_layer, sample_in_embed_dim=self.sample_embed_dim)
275 |
276 | self.fc1.set_sample_config(sample_in_dim=self.sample_embed_dim, sample_out_dim=self.sample_ffn_embed_dim_this_layer)
277 | if self.scale_mlp:
278 | self.mlp_scaling.set_sample_config(self.sample_ffn_embed_dim_this_layer)
279 | self.fc2.set_sample_config(sample_in_dim=self.sample_ffn_embed_dim_this_layer, sample_out_dim=self.sample_out_dim)
280 |
281 | self.ffn_layer_norm.set_sample_config(sample_embed_dim=self.sample_embed_dim)
282 |
283 | if self.scale_embed:
284 | self.embed_scaling1.set_sample_config(sample_embed_dim=self.sample_embed_dim)
285 | self.embed_scaling2.set_sample_config(sample_embed_dim=self.sample_embed_dim)
286 |
287 |
288 | def forward(self, x):
289 | """
290 | Args:
291 | x (Tensor): input to the layer of shape `(batch, patch_num , sample_embed_dim)`
292 |
293 | Returns:
294 | encoded output of shape `(batch, patch_num, sample_embed_dim)`
295 | """
296 | if self.is_identity_layer:
297 | return x
298 |
299 | # compute attn
300 | # start_time = time.time()
301 |
302 | residual = x
303 | x = self.maybe_layer_norm(self.attn_layer_norm, x, before=True)
304 | x = self.attn(x)
305 | if self.scale_embed:
306 | x = self.embed_scaling1(x)
307 | x = F.dropout(x, p=self.sample_attn_dropout, training=self.training)
308 | x = self.drop_path(x)
309 | x = residual + x
310 | x = self.maybe_layer_norm(self.attn_layer_norm, x, after=True)
311 | # print("attn :", time.time() - start_time)
312 | # compute the ffn
313 | # start_time = time.time()
314 | residual = x
315 | x = self.maybe_layer_norm(self.ffn_layer_norm, x, before=True)
316 | x = self.fc1(x)
317 | if self.scale_mlp:
318 | x = self.mlp_scaling(x)
319 | x = self.activation_fn(x)
320 | x = F.dropout(x, p=self.sample_dropout, training=self.training)
321 | x = self.fc2(x)
322 | if self.scale_embed:
323 | x = self.embed_scaling2(x)
324 | x = F.dropout(x, p=self.sample_dropout, training=self.training)
325 | if self.scale:
326 | x = x * (self.super_mlp_ratio / self.sample_mlp_ratio)
327 | x = self.drop_path(x)
328 | x = residual + x
329 | x = self.maybe_layer_norm(self.ffn_layer_norm, x, after=True)
330 | # print("ffn :", time.time() - start_time)
331 | return x
332 |
333 | def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
334 | assert before ^ after
335 | if after ^ self.normalize_before:
336 | return layer_norm(x)
337 | else:
338 | return x
339 | def get_complexity(self, sequence_length):
340 | total_flops = 0
341 | if self.is_identity_layer:
342 | return total_flops
343 | total_flops += self.attn_layer_norm.get_complexity(sequence_length+1)
344 | total_flops += self.attn.get_complexity(sequence_length+1)
345 | total_flops += self.ffn_layer_norm.get_complexity(sequence_length+1)
346 | total_flops += self.fc1.get_complexity(sequence_length+1)
347 | total_flops += self.fc2.get_complexity(sequence_length+1)
348 | return total_flops
349 |
350 | def calc_dropout(dropout, sample_embed_dim, super_embed_dim):
351 | return dropout * 1.0 * sample_embed_dim / super_embed_dim
352 |
353 |
354 |
355 |
356 |
357 |
--------------------------------------------------------------------------------
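Putting the modules together: a sub-network is chosen purely by `set_sample_config`, after which forward passes, parameter counts, and FLOPs all refer to the sampled slices. A minimal end-to-end sketch (the supernet and config values below are illustrative assumptions, not the repo's actual search space):

import torch
from model.supernet_transformer import Vision_TransformerSuper

choices = {'embed_dim': [192, 216, 240], 'depth': [12, 13, 14],
           'mlp_ratio': [3.5, 4.0], 'num_heads': [3, 4]}
supernet = Vision_TransformerSuper(embed_dim=240, depth=14, num_heads=4,
                                   mlp_ratio=4.0, choices=choices)

config = {'embed_dim': [192] * 12, 'mlp_ratio': [3.5] * 12,
          'num_heads': [3] * 12, 'layer_num': 12}
supernet.set_sample_config(config)  # layers 12-13 become identity layers
print(supernet.get_sampled_params_numel(config))     # params of the sampled sub-net
print(supernet(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 1000])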
/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 | from itertools import repeat
5 | import collections.abc as container_abcs
6 | import torch.nn as nn
7 |
8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11 | def norm_cdf(x):
12 | # Computes standard normal cumulative distribution function
13 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
14 |
15 | if (mean < a - 2 * std) or (mean > b + 2 * std):
16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17 | "The distribution of values may be incorrect.",
18 | stacklevel=2)
19 |
20 | with torch.no_grad():
21 | # Values are generated by using a truncated uniform distribution and
22 | # then using the inverse CDF for the normal distribution.
23 | # Get upper and lower cdf values
24 | l = norm_cdf((a - mean) / std)
25 | u = norm_cdf((b - mean) / std)
26 |
27 | # Uniformly fill tensor with values from [l, u], then translate to
28 | # [2l-1, 2u-1].
29 | tensor.uniform_(2 * l - 1, 2 * u - 1)
30 |
31 | # Use inverse cdf transform for normal distribution to get truncated
32 | # standard normal
33 | tensor.erfinv_()
34 |
35 | # Transform to proper mean, std
36 | tensor.mul_(std * math.sqrt(2.))
37 | tensor.add_(mean)
38 |
39 | # Clamp to ensure it's in the proper range
40 | tensor.clamp_(min=a, max=b)
41 | return tensor
42 |
43 |
44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45 | # type: (Tensor, float, float, float, float) -> Tensor
46 | r"""Fills the input Tensor with values drawn from a truncated
47 | normal distribution. The values are effectively drawn from the
48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49 | with values outside :math:`[a, b]` redrawn until they are within
50 | the bounds. The method used for generating the random values works
51 | best when :math:`a \leq \text{mean} \leq b`.
52 | Args:
53 | tensor: an n-dimensional `torch.Tensor`
54 | mean: the mean of the normal distribution
55 | std: the standard deviation of the normal distribution
56 | a: the minimum cutoff value
57 | b: the maximum cutoff value
58 | Examples:
59 | >>> w = torch.empty(3, 5)
60 | >>> nn.init.trunc_normal_(w)
61 | """
62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63 |
64 | def _ntuple(n):
65 | def parse(x):
66 | if isinstance(x, container_abcs.Iterable):
67 | return x
68 | return tuple(repeat(x, n))
69 | return parse
70 |
71 | def drop_path(x, drop_prob: float = 0., training: bool = False):
72 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
73 |
74 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
75 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
76 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
77 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
78 | 'survival rate' as the argument.
79 |
80 | """
81 | if drop_prob == 0. or not training:
82 | return x
83 | keep_prob = 1 - drop_prob
84 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
85 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
86 | random_tensor.floor_() # binarize
87 | output = x.div(keep_prob) * random_tensor
88 | return output
89 |
90 |
91 | class DropPath(nn.Module):
92 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
93 | """
94 | def __init__(self, drop_prob=None):
95 | super(DropPath, self).__init__()
96 | self.drop_prob = drop_prob
97 |
98 | def forward(self, x):
99 | return drop_path(x, self.drop_prob, self.training)
100 |
101 |
102 | to_1tuple = _ntuple(1)
103 | to_2tuple = _ntuple(2)
104 | to_3tuple = _ntuple(3)
105 | to_4tuple = _ntuple(4)
106 | to_ntuple = _ntuple
107 |
--------------------------------------------------------------------------------
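Because `drop_path` rescales surviving samples by `1 / keep_prob`, the expected activation is unchanged; whole samples are zeroed rather than individual units. A quick numerical check (illustrative):

import torch
from model.utils import drop_path

x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.2, training=True)
print((y == 0).all(dim=1).float().mean())  # ~0.2: a fifth of samples fully dropped
print(y.mean())                            # ~1.0: expectation matches x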
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.0
2 | timm==0.3.2
3 | scikit-image
4 | ptflops
5 | easydict
6 | PyYAML
7 | pillow
8 | torchvision==0.8.1  # the release paired with torch 1.7.0; 0.2.1 predates torch 1.x
9 | opencv-python
10 |
--------------------------------------------------------------------------------
/supernet_engine.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable, Optional
4 | from timm.utils.model import unwrap_model
5 | import torch
6 | import torch.distributed as dist
7 | from copy import deepcopy
8 | from timm.data import Mixup
9 | from timm.utils import accuracy, ModelEma
10 | from lib import utils
11 | import random
12 | import time
13 | import json
14 | from contextlib import ExitStack
15 |
16 | def sample_a_cand(candidates, grouping=None, exclude=None):
17 | exclude = exclude or []
18 | while not grouping:  # no grouping given: sample uniformly over all candidates
19 | idx = random.choice(range(len(candidates)))
20 | if idx not in exclude:
21 | return candidates[idx]
22 | while True:  # grouping given: pick a group uniformly, then a candidate within it
23 | idx = random.choice(random.choice(grouping))
24 | if idx not in exclude:
25 | return candidates[idx]
26 |
27 | def sample_candidates(candidates, eval=False, sandwich=0, sandwich_base=True, sandwich_top=True, shuffle=False, grouping=None):
28 | if eval:
29 | return candidates[0]
30 | else:
31 | if sandwich == 0:
32 | cand = sample_a_cand(candidates, grouping)
33 | if shuffle:
34 | cand = deepcopy(cand)
35 | random.shuffle(cand['mlp_ratio'])
36 | random.shuffle(cand['num_heads'])
37 | return cand
38 | else:
39 | base_cand = []
40 | top_cand = []
41 | exclude = []
42 | if sandwich_base:
43 | base_cand = [candidates[0]]
44 | exclude.append(0)
45 | if sandwich_top:
46 | top_cand = [candidates[-1]]
47 | exclude.append(len(candidates)-1)
48 | inter_cands = [sample_a_cand(candidates, grouping, exclude) for _ in range(sandwich)]
49 | return base_cand + inter_cands + top_cand
50 |
51 | def sample_a_config(choices, efunc=random.choice):
52 | config = {}
53 | embed_dim = efunc(choices['embed_dim'])
54 | if isinstance(choices['depth'], dict):
55 | depth = efunc(choices['depth'][embed_dim])
56 | else:
57 | depth = efunc(choices['depth'])
58 | dimensions = ['mlp_ratio', 'num_heads']
59 | for dimension in dimensions:
60 | if isinstance(choices[dimension], dict):
61 | config[dimension] = [efunc(choices[dimension][embed_dim][i]) for i in range(depth)]
62 | else:
63 | config[dimension] = [efunc(choices[dimension]) for _ in range(depth)]
64 | config['embed_dim'] = [embed_dim] * depth
65 | config['layer_num'] = depth
66 | return config
67 |
68 | def sample_configs(choices, eval=False, sandwich=0, sandwich_base=True, sandwich_top=True):
69 | if eval:
70 | return sample_a_config(choices, min)
71 | else:
72 | if sandwich == 0:
73 | return sample_a_config(choices)
74 | else:
75 | base_config = [sample_a_config(choices, min)] if sandwich_base else []
76 | top_config = [sample_a_config(choices, max)] if sandwich_top else []
77 | inter_configs = [sample_a_config(choices) for _ in range(sandwich)]
78 | return base_config + inter_configs + top_config
79 |
80 | def bp_once(loss, loss_scaler=None, create_graph=False):
81 | if loss_scaler:
82 | loss_scaler._scaler.scale(loss).backward(create_graph=create_graph)
83 | else:
84 | loss.backward()
85 |
86 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
87 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
88 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
89 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
90 | amp: bool = True, teacher_model: torch.nn.Module = None,
91 | teach_loss: torch.nn.Module = None, choices=None, mode='super', retrain_config=None,
92 | print2file=False, candidates=None, sandwich=0, sandwich_base=True, sandwich_top=True,
93 | shuffle=False, grouping=None):
94 | model.train()
95 | criterion.train()
96 |
97 | # set random seed
98 | random.seed(epoch)
99 |
100 | metric_logger = utils.MetricLogger(delimiter=" ")
101 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
102 | header = 'Epoch: [{}]'.format(epoch)
103 | print_freq = 10
104 | if mode == 'retrain':
105 | config = retrain_config
106 | model_module = unwrap_model(model)
107 | print("DEBUG:retrain {}".format(config), force=print2file)
108 | model_module.set_sample_config(config=config)
109 | print("DEBUG:retrain {}".format(model_module.get_sampled_params_numel(config)), force=print2file)
110 |
111 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
112 | samples = samples.to(device, non_blocking=True)
113 | targets = targets.to(device, non_blocking=True)
114 |
115 | # sample random config
116 | if mode == 'super':
117 | sandwich_args = {'sandwich': sandwich,
118 | 'sandwich_base': sandwich_base,
119 | 'sandwich_top': sandwich_top,
120 | }
121 | if candidates is not None:
122 | config = sample_candidates(candidates, **sandwich_args, shuffle=shuffle, grouping=grouping)
123 | else:
124 | config = sample_configs(choices, **sandwich_args)
125 | if isinstance(config, dict):
126 | config = [config]
127 | model_module = unwrap_model(model)
128 | #model_module.set_sample_config(config=config)
129 | elif mode == 'retrain':
130 | config = retrain_config
131 | model_module = unwrap_model(model)
132 | model_module.set_sample_config(config=config)
133 | if mixup_fn is not None:
134 | samples, targets = mixup_fn(samples, targets)
135 |
136 | optimizer.zero_grad()
137 | # this attribute is added by timm on one optimizer (adahessian)
138 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
139 |
140 | loss_value = 0.0
141 | with ExitStack() if not amp else torch.cuda.amp.autocast():
142 | if teacher_model:
143 | with torch.no_grad():
144 | teach_output = teacher_model(samples)
145 | _, teacher_label = teach_output.topk(1, 1, True, True)
146 | del teach_output
147 | teacher_label.squeeze_()
148 | factor = 1.0 / len(config)
149 | for cfg in config:
150 | model_module.set_sample_config(cfg)
151 | outputs = model(samples)
152 | # gt
153 | loss = 0.5 * factor * criterion(outputs, targets)
154 | bp_once(loss, loss_scaler, is_second_order)
155 | loss_value += loss.item()
156 | # teacher
157 | loss = 0.5 * factor * teach_loss(outputs, teacher_label)
158 | bp_once(loss, loss_scaler, is_second_order)
159 | loss_value += loss.item()
160 | else:
161 | factor = 1.0 / len(config)
162 | for cfg in config:
163 | model_module.set_sample_config(cfg)
164 | loss = factor * criterion(model(samples), targets)
165 | bp_once(loss, loss_scaler, is_second_order)
166 | loss_value += loss.item()
167 |
168 | if not math.isfinite(loss_value):
169 | print("Loss is {}, stopping training".format(loss_value))
170 | #sys.exit(1)
171 | continue
172 |
173 | if amp:
174 | loss_scaler(loss, optimizer, clip_grad=max_norm,
175 | parameters=model.parameters(), create_graph=is_second_order)
176 | else:
177 | optimizer.step()
178 |
179 | torch.cuda.synchronize()
180 | if model_ema is not None:
181 | model_ema.update(model)
182 |
183 | metric_logger.update(loss=loss_value)
184 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
185 |
186 | # only check at the end of epoch (avoid flooding)
187 | print("DEBUG:train {}".format(config), force=print2file)
188 |
189 | # gather the stats from all processes
190 | metric_logger.synchronize_between_processes()
191 | print("Averaged stats:", metric_logger)
192 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
193 |
194 | @torch.no_grad()
195 | def evaluate(data_loader, model, device, amp=True, choices=None, mode='super', retrain_config=None, print2file=False, candidates=None, eval_crops=1):
196 | criterion = torch.nn.CrossEntropyLoss()
197 |
198 | metric_logger = utils.MetricLogger(delimiter=" ")
199 | header = 'Test:'
200 |
201 | # switch to evaluation mode
202 | model.eval()
203 | if mode == 'super':
204 | if candidates is not None:
205 | config = sample_candidates(candidates, eval=True)
206 | else:
207 | config = sample_configs(choices, eval=True)
208 | config = [config]
209 | if utils.is_dist_avail_and_initialized():
210 | dist.broadcast_object_list(config, src=0)
211 | config = config[0]
212 | model_module = unwrap_model(model)
213 | model_module.set_sample_config(config=config)
214 | else:
215 | config = retrain_config
216 | model_module = unwrap_model(model)
217 | model_module.set_sample_config(config=config)
218 |
219 |
220 | print("DEBUG:eval sampled model config: {}".format(config), force=print2file)
221 | parameters = model_module.get_sampled_params_numel(config)
222 | print("DEBUG:eval sampled model parameters: {}".format(parameters), force=print2file)
223 |
224 | for images, target in metric_logger.log_every(data_loader, 10, header):
225 | images = images.to(device, non_blocking=True)
226 | target = target.to(device, non_blocking=True)
227 |
228 | if eval_crops > 1:
229 | bs, ncrops, c, h, w = images.size()
230 | images = images.view(-1, c, h, w)
231 |
232 | # compute output
233 | with ExitStack() if not amp else torch.cuda.amp.autocast():
234 | output = model(images)
235 | if eval_crops > 1:
236 | output = output.view(bs, ncrops, -1).mean(1)
237 | loss = criterion(output, target)
238 |
239 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
240 |
241 | batch_size = images.shape[0]
242 | metric_logger.update(loss=loss.item())
243 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
244 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
245 | # gather the stats from all processes
246 | metric_logger.synchronize_between_processes()
247 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
248 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
249 |
250 | metric_logger.update(n_parameters=parameters)
251 |
252 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
253 |
--------------------------------------------------------------------------------
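For reference, the sampling helpers can be exercised standalone. In 'super' mode, `train_one_epoch` draws one random config per step, or a smallest + random + largest "sandwich" when `sandwich > 0`, and splits the gradient evenly across them. A sketch with an assumed `choices` dict:

import random
from supernet_engine import sample_configs

choices = {'embed_dim': [192, 216, 240], 'depth': [12, 13, 14],
           'mlp_ratio': [3.5, 4.0], 'num_heads': [3, 4]}

random.seed(0)
cfg = sample_configs(choices)                # one random sub-network config
print(cfg['layer_num'], cfg['embed_dim'][0])

cfgs = sample_configs(choices, sandwich=2)   # [smallest, rand, rand, largest]
print(len(cfgs))                             # 4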
/supernet_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import sys
5 | import numpy as np
6 | import time
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 | import json
10 | import yaml
11 | from collections import defaultdict
12 | from pathlib import Path
13 | from pprint import pprint
14 | from timm.data import Mixup
15 | from timm.models import create_model
16 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
17 | from timm.scheduler import create_scheduler
18 | from timm.optim import create_optimizer
19 | #from timm.utils import NativeScaler
20 | from lib.cuda import NativeScaler
21 | from lib.datasets import build_dataset
22 | from supernet_engine import train_one_epoch, evaluate
23 | from lib.samplers import RASampler
24 | from lib import utils
25 | from lib.config import cfg, update_config_from_file
26 | from model.supernet_transformer import Vision_TransformerSuper
27 |
28 |
29 | def get_args_parser():
30 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', add_help=False)
31 | parser.add_argument('--batch-size', default=64, type=int)
32 | parser.add_argument('--epochs', default=300, type=int)
33 | # config file
34 | parser.add_argument('--cfg',help='experiment configure file name',required=True,type=str)
35 |
36 | # custom parameters
37 | parser.add_argument('--platform', default='pai', type=str, choices=['itp', 'pai', 'aml'],
38 | help='Platform to run on')
39 | parser.add_argument('--teacher_model', default='', type=str,
40 | help='Name of teacher model for distillation')
41 | parser.add_argument('--relative_position', action='store_true')
42 | parser.add_argument('--gp', action='store_true')
43 | parser.add_argument('--change_qkv', action='store_true')
44 | parser.add_argument('--max_relative_position', type=int, default=14, help='max distance in relative position embedding')
45 |
46 | # Model parameters
47 | parser.add_argument('--model', default='', type=str, metavar='MODEL',
48 | help='Name of model to train')
49 | # AutoFormer config
50 | parser.add_argument('--mode', type=str, default='super', choices=['super', 'retrain'], help='mode of AutoFormer')
51 | parser.add_argument('--input-size', default=224, type=int)
52 | parser.add_argument('--patch_size', default=16, type=int)
53 |
54 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
55 | help='Dropout rate (default: 0.)')
56 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
57 | help='Drop path rate (default: 0.1)')
58 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
59 | help='Drop block rate (default: None)')
60 |
61 | parser.add_argument('--model-ema', action='store_true')
62 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
63 | # parser.set_defaults(model_ema=True)
64 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
65 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
66 | parser.add_argument('--rpe_type', type=str, default='bias', choices=['bias', 'direct'])
67 | parser.add_argument('--post_norm', action='store_true')
68 | parser.add_argument('--no_abs_pos', action='store_true')
69 |
70 | # Optimizer parameters
71 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
72 | help='Optimizer (default: "adamw")')
73 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
74 | help='Optimizer Epsilon (default: 1e-8)')
75 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
76 | help='Optimizer Betas (default: None, use opt default)')
77 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
78 | help='Clip gradient norm (default: None, no clipping)')
79 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
80 | help='SGD momentum (default: 0.9)')
81 | parser.add_argument('--weight-decay', type=float, default=0.05,
82 | help='weight decay (default: 0.05)')
83 |
84 | # Learning rate schedule parameters
85 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
86 | help='LR scheduler (default: "cosine")')
87 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
88 | help='learning rate (default: 5e-4)')
89 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
90 | help='learning rate noise on/off epoch percentages')
91 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
92 | help='learning rate noise limit percent (default: 0.67)')
93 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
94 | help='learning rate noise std-dev (default: 1.0)')
95 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
96 | help='warmup learning rate (default: 1e-6)')
97 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
98 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
99 | parser.add_argument('--lr-power', type=float, default=1.0,
100 | help='power of the polynomial lr scheduler')
101 |
102 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
103 | help='epoch interval to decay LR')
104 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
105 | help='epochs to warmup LR, if scheduler supports')
106 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
107 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
108 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
109 | help='patience epochs for Plateau LR scheduler (default: 10)')
110 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
111 | help='LR decay rate (default: 0.1)')
112 |
113 | # Augmentation parameters
114 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
115 | help='Color jitter factor (default: 0.4)')
116 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
117 | help='Use AutoAugment policy. "v0" or "original" '
118 | '(default: rand-m9-mstd0.5-inc1)')
119 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
120 | parser.add_argument('--train-interpolation', type=str, default='bicubic',
121 | help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
122 |
123 | parser.add_argument('--repeated-aug', action='store_true')
124 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
125 |
126 |
127 | parser.set_defaults(repeated_aug=True)
128 |
129 | # * Random Erase params
130 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
131 | help='Random erase prob (default: 0.25)')
132 | parser.add_argument('--remode', type=str, default='pixel',
133 | help='Random erase mode (default: "pixel")')
134 | parser.add_argument('--recount', type=int, default=1,
135 | help='Random erase count (default: 1)')
136 | parser.add_argument('--resplit', action='store_true', default=False,
137 | help='Do not random erase first (clean) augmentation split')
138 |
139 | # * Mixup params
140 | parser.add_argument('--mixup', type=float, default=0.8,
141 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
142 | parser.add_argument('--cutmix', type=float, default=1.0,
143 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
144 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
145 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
146 | parser.add_argument('--mixup-prob', type=float, default=1.0,
147 | help='Probability of performing mixup or cutmix when either/both is enabled')
148 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
149 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
150 | parser.add_argument('--mixup-mode', type=str, default='batch',
151 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
152 |
153 | # Dataset parameters
154 | parser.add_argument('--data-path', default='./data/imagenet/', type=str,
155 | help='dataset path')
156 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19', 'EVO_IMNET'],
158 | type=str, help='dataset to use')
158 | parser.add_argument('--inat-category', default='name',
159 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
160 | type=str, help='semantic granularity')
161 |
162 | parser.add_argument('--output_dir', default='./logs/',
163 | help='path where to save, empty for no saving')
164 | parser.add_argument('--task', default='', help='task prefix')
165 | parser.add_argument('--device', default='cuda',
166 | help='device to use for training / testing')
167 | parser.add_argument('--seed', default=0, type=int)
168 | parser.add_argument('--resume', default='', help='resume from checkpoint')
169 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
170 | help='start epoch')
171 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
172 | parser.add_argument('--num_workers', default=10, type=int)
173 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
174 | parser.add_argument('--pin-mem', action='store_true',
175 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
176 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
177 | help='disable pinning CPU memory in DataLoader')
178 | parser.set_defaults(pin_mem=True)
179 |
180 | # distributed training parameters
181 | parser.add_argument('--world_size', default=1, type=int,
182 | help='number of distributed processes')
183 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
184 |
185 | parser.add_argument('--amp', action='store_true')
186 | parser.add_argument('--no-amp', action='store_false', dest='amp')
187 | parser.set_defaults(amp=True)
188 |
189 | parser.add_argument('--print2file', action='store_true', default=False, help='save stdout to file')
190 | parser.add_argument('--candfile', default='', type=str, help='candidates json file')
191 | parser.add_argument('--group-by-dim', action='store_true', default=False, help='group candidates by embed_dim')
192 | parser.add_argument('--group-by-depth', action='store_true', default=False, help='group candidates by depth')
193 | parser.add_argument('--sandwich', default=0, type=int, help='number of intermediate subnets in the sandwich, 0 to turn off')
194 | parser.add_argument('--no-sandwich-base', action='store_true', default=False, help='remove the base layer of sandwich')
195 | parser.add_argument('--no-sandwich-top', action='store_true', default=False, help='remove the top layer of sandwich')
196 | parser.add_argument('--switch-ln', action='store_true', default=False, help='Enabling switchable layernorm')
197 | parser.add_argument('--scale-attn', action='store_true', default=False, help='scale attention')
198 | parser.add_argument('--scale-mlp', action='store_true', default=False, help='scale mlp')
199 | parser.add_argument('--scale-embed', action='store_true', default=False, help='scale embed dim')
200 | parser.add_argument('--shuffle', action='store_true', default=False, help='shuffle chosen candidate')
201 | parser.add_argument('--eval-crops', default=1, type=int, choices=[1, 5, 10], help='number of crops for evaluation')
202 |
203 | return parser
204 |
205 | def main(args):
206 |
207 | utils.init_distributed_mode(args)
208 | update_config_from_file(args.cfg)
209 |
210 | print(args)
211 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
212 |
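    | # stash the original stream on sys._stdout and redirect prints to a per-rank log file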
213 | if args.print2file:
214 | sys._stdout = sys.stdout
215 | sys.stdout = open(os.path.join(args.output_dir, f'out{utils.get_rank()}'), 'w', buffering=1)
216 |
217 | device = torch.device(args.device)
218 |
219 | # fix the seed for reproducibility
220 | seed = args.seed + utils.get_rank()
221 | torch.manual_seed(seed)
222 | np.random.seed(seed)
223 | # random.seed(seed)
224 | cudnn.benchmark = True
225 |
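    | # EVO_IMNET trains on the subImageNet subset; other datasets use the full train split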
226 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args, folder_name="subImageNet" if args.data_set == "EVO_IMNET" else "train")
227 | dataset_val, _ = build_dataset(is_train=False, args=args, folder_name="val")
228 |
229 | if args.distributed:
230 | num_tasks = utils.get_world_size()
231 | global_rank = utils.get_rank()
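    | # with repeated augmentation, RASampler spreads augmented copies of each image across processes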
232 | if args.repeated_aug:
233 | sampler_train = RASampler(
234 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
235 | )
236 | else:
237 | sampler_train = torch.utils.data.DistributedSampler(
238 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
239 | )
240 | if args.dist_eval:
241 | if len(dataset_val) % num_tasks != 0:
242 | print(
243 | 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
244 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
245 | 'equal num of samples per-process.')
246 | sampler_val = torch.utils.data.DistributedSampler(
247 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
248 | else:
249 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
250 | else:
251 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
252 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
253 |
254 | data_loader_train = torch.utils.data.DataLoader(
255 | dataset_train, sampler=sampler_train,
256 | batch_size=args.batch_size,
257 | num_workers=args.num_workers,
258 | pin_memory=args.pin_mem,
259 | drop_last=True,
260 | )
261 |
262 | data_loader_val = torch.utils.data.DataLoader(
263 | dataset_val, batch_size=int(2 * args.batch_size),
264 | sampler=sampler_val, num_workers=args.num_workers,
265 | pin_memory=args.pin_mem, drop_last=False
266 | )
267 |
268 | mixup_fn = None
269 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
270 | if mixup_active:
271 | mixup_fn = Mixup(
272 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
273 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
274 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
275 |
276 | print(f"Creating SuperVisionTransformer")
277 | print(cfg)
278 |
279 | choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
280 | 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM, 'depth': cfg.SEARCH_SPACE.DEPTH}
281 |
282 | model = Vision_TransformerSuper(img_size=args.input_size,
283 | patch_size=args.patch_size,
284 | embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH,
285 | num_heads=cfg.SUPERNET.NUM_HEADS, mlp_ratio=cfg.SUPERNET.MLP_RATIO,
286 | qkv_bias=True, drop_rate=args.drop,
287 | drop_path_rate=args.drop_path,
288 | gp=args.gp,
289 | num_classes=args.nb_classes,
290 | max_relative_position=args.max_relative_position,
291 | relative_position=args.relative_position,
292 | change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos,
293 | choices=choices, switch_ln=args.switch_ln,
294 | scale_attn=args.scale_attn, scale_mlp=args.scale_mlp,
295 | scale_embed=args.scale_embed)
296 |
297 | model.to(device)
298 | if args.teacher_model:
299 | teacher_model = create_model(
300 | args.teacher_model,
301 | pretrained=True,
302 | num_classes=args.nb_classes,
303 | )
304 | teacher_model.to(device)
305 | teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
306 | else:
307 | teacher_model = None
308 | teacher_loss = None
309 |
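    | # no EMA model is maintained during supernet training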
310 | model_ema = None
311 |
312 | print(model)
313 |
314 | model_without_ddp = model
315 | if args.distributed:
316 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
317 | model_without_ddp = model.module
318 |
319 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
320 | print('number of params:', n_parameters)
321 |
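    | # linear LR scaling rule: the base lr is defined for a global batch size of 512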
322 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
323 | args.lr = linear_scaled_lr
324 | optimizer = create_optimizer(args, model_without_ddp)
325 | loss_scaler = NativeScaler()
326 | lr_scheduler, _ = create_scheduler(args, optimizer)
327 |
328 | # criterion = LabelSmoothingCrossEntropy()
329 |
330 | if args.mixup > 0.:
331 | # smoothing is handled with mixup label transform
332 | criterion = SoftTargetCrossEntropy()
333 | elif args.smoothing:
334 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
335 | else:
336 | criterion = torch.nn.CrossEntropyLoss()
337 |
338 | output_dir = Path(args.output_dir)
339 |
340 | # exist_ok guards against races between distributed ranks
341 | output_dir.mkdir(parents=True, exist_ok=True)
342 | # save config for later experiments
343 | if args.output_dir and utils.is_main_process():
344 | with open(output_dir / "config.yaml", 'w') as f:
345 | f.write(args_text)
346 | if args.resume:
347 | if args.resume.startswith('https'):
348 | checkpoint = torch.hub.load_state_dict_from_url(
349 | args.resume, map_location='cpu', check_hash=True)
350 | else:
351 | checkpoint = torch.load(args.resume, map_location='cpu')
352 | model_without_ddp.load_state_dict(checkpoint['model'])
353 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
354 | optimizer.load_state_dict(checkpoint['optimizer'])
355 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
356 | args.start_epoch = checkpoint['epoch'] + 1
357 | if 'scaler' in checkpoint:
358 | loss_scaler.load_state_dict(checkpoint['scaler'])
359 | if args.model_ema:
360 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
361 |
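    | # retrain mode trains one fixed architecture read from the RETRAIN section of the config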
362 | retrain_config = None
363 | if args.mode == 'retrain' and "RETRAIN" in cfg:
364 | retrain_config = {'layer_num': cfg.RETRAIN.DEPTH, 'embed_dim': [cfg.RETRAIN.EMBED_DIM]*cfg.RETRAIN.DEPTH,
365 | 'num_heads': cfg.RETRAIN.NUM_HEADS, 'mlp_ratio': cfg.RETRAIN.MLP_RATIO}
366 | candidates = None
367 | cand_index_group = None
368 | if args.candfile:
369 | cand_dict = json.load(open(args.candfile))
370 | candidates = [cand for cand_list in cand_dict.values() for cand in cand_list]
371 | if args.group_by_dim or args.group_by_depth:
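    | # bucket candidate indices that share the same embed_dim and/or depth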
372 | cand_index_group = defaultdict(list)
373 | for idx, cand in enumerate(candidates):
374 | k = ()
375 | k = k + ((cand['embed_dim'][0],) if args.group_by_dim else ())
376 | k = k + ((cand['layer_num'],) if args.group_by_depth else ())
377 | cand_index_group[k].append(idx)
378 | cand_index_group = list(cand_index_group.values())
379 |
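    | # evaluation-only path: test every candidate from the candidate file, or the single retrain config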
380 | if args.eval:
381 | if args.candfile:
382 | batch_stats = []
383 | for retrain_config in candidates:
384 | test_stats = evaluate(data_loader_val, model, device, mode=args.mode, retrain_config=retrain_config, print2file=args.print2file, eval_crops=args.eval_crops)
385 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
386 | test_stats['config'] = retrain_config
387 | batch_stats.append(test_stats)
388 | if args.output_dir and utils.is_main_process():
389 | json.dump(batch_stats, open(output_dir / 'results.json', 'w'), indent=2)
390 | else:
391 | test_stats = evaluate(data_loader_val, model, device, mode=args.mode, retrain_config=retrain_config, print2file=args.print2file, eval_crops=args.eval_crops)
392 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
393 |
394 | return
395 |
396 | print("Start training")
397 | start_time = time.time()
398 | max_accuracy = 0.0
399 |
400 | for epoch in range(args.start_epoch, args.epochs):
401 | if args.distributed:
402 | data_loader_train.sampler.set_epoch(epoch)
403 |
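    | # one supernet epoch; the candidates, sandwich, and grouping flags control which subnets are sampled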
404 | train_stats = train_one_epoch(
405 | model, criterion, data_loader_train,
406 | optimizer, device, epoch, loss_scaler,
407 | args.clip_grad, model_ema, mixup_fn,
408 | amp=args.amp, teacher_model=teacher_model,
409 | teach_loss=teacher_loss,
410 | choices=choices, mode=args.mode, retrain_config=retrain_config,
411 | print2file=args.print2file, candidates=candidates, sandwich=args.sandwich,
412 | sandwich_base=not args.no_sandwich_base, sandwich_top=not args.no_sandwich_top,
413 | shuffle=args.shuffle, grouping=cand_index_group,
414 | )
415 |
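    | # epoch-based scheduler step, with a checkpoint every 10 epochs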
416 | lr_scheduler.step(epoch)
417 | if args.output_dir and (epoch + 1) % 10 == 0:
418 | checkpoint_paths = [output_dir / 'checkpoint-{}.pth'.format(epoch+1)]
419 | for checkpoint_path in checkpoint_paths:
420 | utils.save_on_master({
421 | 'model': model_without_ddp.state_dict(),
422 | 'optimizer': optimizer.state_dict(),
423 | 'lr_scheduler': lr_scheduler.state_dict(),
424 | 'epoch': epoch,
425 | # 'model_ema': get_state_dict(model_ema),
426 | 'scaler': loss_scaler.state_dict(),
427 | 'args': args,
428 | }, checkpoint_path)
429 |
430 | test_stats = evaluate(data_loader_val, model, device, amp=args.amp, choices=choices, mode=args.mode, retrain_config=retrain_config, print2file=args.print2file, candidates=candidates, eval_crops=args.eval_crops)
431 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
432 | max_accuracy = max(max_accuracy, test_stats["acc1"])
433 | print(f'Max accuracy: {max_accuracy:.2f}%')
434 |
435 | log_stats = {'datetime': datetime.datetime.now().strftime("%m/%d %H:%M"),
436 | **{f'train_{k}': v for k, v in train_stats.items()},
437 | **{f'test_{k}': v for k, v in test_stats.items()},
438 | 'epoch': epoch}
439 |
440 | if args.output_dir and utils.is_main_process():
441 | with (output_dir / "log.txt").open("a") as f:
442 | f.write(json.dumps(log_stats) + "\n")
443 |
444 | total_time = time.time() - start_time
445 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
446 | print('Training time {}'.format(total_time_str))
447 |
448 |
449 | if __name__ == '__main__':
450 | now = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
451 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', parents=[get_args_parser()])
452 | args = parser.parse_args()
453 | if args.task:
454 | suffix = args.task
455 | elif args.resume:
456 | resume_folder = os.path.basename(os.path.dirname(os.path.normpath(args.resume)))
457 | suffix = resume_folder.partition('@')[-1]
458 | else:
459 | suffix = ''
460 | sep = '@' if suffix else ''
461 | args.output_dir = os.path.join(args.output_dir, 'test' if args.eval else 'train', now+sep+suffix)
462 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
463 | main(args)
464 |
--------------------------------------------------------------------------------
/two_step_search.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import time
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import json
8 | import yaml
9 | import random
10 | from pathlib import Path
11 | from timm.data import Mixup
12 | from timm.models import create_model
13 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
14 | from timm.utils import NativeScaler
15 | from lib.datasets import build_dataset
16 | from lib.samplers import RASampler
17 | from lib import utils
18 | from lib.config import cfg, update_config_from_file
19 | from lib.score_maker import ScoreMaker
20 | from model.supernet_transformer import Vision_TransformerSuper
21 | from evolution_pre_train import Searcher
22 |
23 |
24 | def get_args_parser():
25 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', add_help=False)
26 | parser.add_argument('--batch-size', default=64, type=int)
27 | # config file
28 | parser.add_argument('--cfg',help='experiment configure file name',required=True,type=str)
29 |
30 | # custom parameters
31 | parser.add_argument('--platform', default='pai', type=str, choices=['itp', 'pai', 'aml'],
32 | help='name of the training platform')
33 | parser.add_argument('--teacher_model', default='', type=str,
34 | help='Name of teacher model to train')
35 | parser.add_argument('--relative_position', action='store_true')
36 | parser.add_argument('--gp', action='store_true')
37 | parser.add_argument('--change_qkv', action='store_true')
38 | parser.add_argument('--max_relative_position', type=int, default=14, help='max distance in relative position embedding')
39 |
40 | # Model parameters
41 | parser.add_argument('--model', default='', type=str, metavar='MODEL',
42 | help='Name of model to train')
43 |
44 | parser.add_argument('--input-size', default=224, type=int)
45 | parser.add_argument('--patch_size', default=16, type=int)
46 |
47 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
48 | help='Dropout rate (default: 0.)')
49 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
50 | help='Drop path rate (default: 0.1)')
51 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
52 | help='Drop block rate (default: None)')
53 |
54 | parser.add_argument('--model-ema', action='store_true')
55 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
56 | # parser.set_defaults(model_ema=True)
57 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='decay factor for model EMA')
58 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='keep model EMA on CPU')
59 | parser.add_argument('--rpe_type', type=str, default='bias', choices=['bias', 'direct'])
60 | parser.add_argument('--post_norm', action='store_true')
61 | parser.add_argument('--no_abs_pos', action='store_true')
62 |
63 | # Optimizer parameters
64 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
65 | help='Optimizer (default: "adamw")')
66 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
67 | help='Optimizer Epsilon (default: 1e-8)')
68 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
69 | help='Optimizer Betas (default: None, use opt default)')
70 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
71 | help='Clip gradient norm (default: None, no clipping)')
72 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
73 | help='SGD momentum (default: 0.9)')
74 | parser.add_argument('--weight-decay', type=float, default=0.05,
75 | help='weight decay (default: 0.05)')
76 |
77 | # Learning rate schedule parameters
78 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
79 | help='LR scheduler (default: "cosine")')
80 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
81 | help='learning rate (default: 5e-4)')
82 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
83 | help='learning rate noise on/off epoch percentages')
84 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
85 | help='learning rate noise limit percent (default: 0.67)')
86 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
87 | help='learning rate noise std-dev (default: 1.0)')
88 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
89 | help='warmup learning rate (default: 1e-6)')
90 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
91 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
92 | parser.add_argument('--lr-power', type=float, default=1.0,
93 | help='power of the polynomial lr scheduler')
94 |
95 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
96 | help='epoch interval to decay LR')
97 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
98 | help='epochs to warmup LR, if scheduler supports')
99 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
100 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
101 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
102 | help='patience epochs for Plateau LR scheduler (default: 10)')
103 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
104 | help='LR decay rate (default: 0.1)')
105 |
106 | # Augmentation parameters
107 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
108 | help='Color jitter factor (default: 0.4)')
109 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
110 | help='Use AutoAugment policy. "v0" or "original" '
111 | '(default: rand-m9-mstd0.5-inc1)')
112 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
113 | parser.add_argument('--train-interpolation', type=str, default='bicubic',
114 | help='Training interpolation (random, bilinear, bicubic; default: "bicubic")')
115 |
116 | parser.add_argument('--repeated-aug', action='store_true')
117 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
118 |
119 | # * Random Erase params
120 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
121 | help='Random erase prob (default: 0.25)')
122 | parser.add_argument('--remode', type=str, default='pixel',
123 | help='Random erase mode (default: "pixel")')
124 | parser.add_argument('--recount', type=int, default=1,
125 | help='Random erase count (default: 1)')
126 | parser.add_argument('--resplit', action='store_true', default=False,
127 | help='Do not random erase first (clean) augmentation split')
128 |
129 | # * Mixup params
130 | parser.add_argument('--mixup', type=float, default=0.8,
131 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
132 | parser.add_argument('--cutmix', type=float, default=1.0,
133 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
134 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
135 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
136 | parser.add_argument('--mixup-prob', type=float, default=1.0,
137 | help='Probability of performing mixup or cutmix when either/both is enabled')
138 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
139 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
140 | parser.add_argument('--mixup-mode', type=str, default='batch',
141 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
142 |
143 | # Dataset parameters
144 | parser.add_argument('--data-path', default='./data/imagenet/', type=str,
145 | help='dataset path')
146 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
147 | type=str, help='dataset to use')
148 | parser.add_argument('--inat-category', default='name',
149 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
150 | type=str, help='semantic granularity')
151 | parser.add_argument('--output_dir', default='./',
152 | help='path where to save, empty for no saving')
153 | parser.add_argument('--device', default='cuda',
154 | help='device to use for training / testing')
155 | parser.add_argument('--seed', default=0, type=int)
156 | parser.add_argument('--resume', default='', help='resume from checkpoint')
157 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
158 | help='start epoch')
159 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
160 | parser.add_argument('--num_workers', default=10, type=int)
161 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
162 | parser.add_argument('--pin-mem', action='store_true',
163 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
164 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
165 | parser.set_defaults(pin_mem=True)
166 |
167 | # distributed training parameters
168 | parser.add_argument('--world_size', default=1, type=int,
169 | help='number of distributed processes')
170 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
171 | parser.add_argument('--amp', action='store_true')
172 | parser.add_argument('--no-amp', action='store_false', dest='amp')
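    |
    | # Two-step search parameters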
173 | parser.add_argument('--score-method', default='left_super_taylor9', type=str,
174 | help='Score method in step two')
175 | parser.add_argument('--block-score-method-for-head', default='balance_taylor5_max_dim', type=str,
176 | help='Score method for head in step one')
177 | parser.add_argument('--block-score-method-for-mlp', default='deeper_is_better', type=str,
178 | help='Score method for mlp in step one')
179 | parser.add_argument('--candidate-path', default='', type=str, help='the path of interval candidates')
180 | parser.add_argument('--super-model-size', default='T', type=str)
181 | parser.add_argument('--interval-cands-output', default='./out/interval_candidates.pt', type=str)
182 | parser.add_argument('--min_param_limits', default=4, type=float)
183 | parser.add_argument('--param_limits', default=12, type=float)
184 | parser.add_argument('--param-interval', default=2, type=float)
185 | parser.add_argument('--cand-per-interval', default=1, type=int)
186 | parser.add_argument('--population-num', default=50, type=int)
187 | parser.add_argument('--max-epochs', default=20, type=int)
188 | parser.add_argument('--select-num', type=int, default=20)
189 | parser.add_argument('--m_prob', type=float, default=0.2)
190 | parser.add_argument('--s_prob', type=float, default=0.4)
191 | parser.add_argument('--crossover-num', type=int, default=25)
192 | parser.add_argument('--mutation-num', type=int, default=25)
193 | parser.add_argument('--data-free', action='store_true', help='use data-free scoring; omit to compute gradients from data')
194 | parser.add_argument('--reallocate', action='store_true', help='reallocate dimensions during random and evolution search')
195 | parser.add_argument('--avg-dim-sample', action='store_true', help='sample dimensions from a uniform distribution')
196 | parser.add_argument('--search-mode', default='iteration', choices=['iteration', 'random', 'evolution'], type=str, help='mode used to search the candidates')
197 | parser.set_defaults(amp=True)
198 |
199 | return parser
200 |
201 |
202 | def main(args):
203 |
204 | utils.init_distributed_mode(args)
205 | update_config_from_file(args.cfg)
206 |
207 | print(args)
208 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
209 |
210 | device = torch.device(args.device)
211 |
212 | # fix the seed for reproducibility
213 | seed = args.seed + utils.get_rank()
214 | torch.manual_seed(seed)
215 | np.random.seed(seed)
216 | # random.seed(seed)
217 | cudnn.benchmark = True
218 |
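    | # the search runs on the subImageNet proxy split only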
219 | dataset_sub_train, args.nb_classes = build_dataset(is_train=True, args=args, folder_name="subImageNet")
220 |
221 | if args.distributed:
222 | num_tasks = utils.get_world_size()
223 | global_rank = utils.get_rank()
224 | if args.repeated_aug:
225 | sampler_sub_train = RASampler(
226 | dataset_sub_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
227 | )
228 | else:
229 | sampler_sub_train = torch.utils.data.DistributedSampler(
230 | dataset_sub_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
231 | )
232 | else:
233 | sampler_sub_train = torch.utils.data.RandomSampler(dataset_sub_train)
234 |
235 | data_loader_sub_train = torch.utils.data.DataLoader(
236 | dataset_sub_train, batch_size=args.batch_size,
237 | sampler=sampler_sub_train, num_workers=args.num_workers,
238 | pin_memory=args.pin_mem, drop_last=False
239 | )
240 |
241 | mixup_fn = None
242 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
243 | if mixup_active:
244 | mixup_fn = Mixup(
245 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
246 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
247 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
248 |
249 | print(f"Creating SuperVisionTransformer")
250 | print(cfg)
251 | model = Vision_TransformerSuper(img_size=args.input_size,
252 | patch_size=args.patch_size,
253 | embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH,
254 | num_heads=cfg.SUPERNET.NUM_HEADS, mlp_ratio=cfg.SUPERNET.MLP_RATIO,
255 | qkv_bias=True, drop_rate=args.drop,
256 | drop_path_rate=args.drop_path,
257 | gp=args.gp,
258 | num_classes=args.nb_classes,
259 | max_relative_position=args.max_relative_position,
260 | relative_position=args.relative_position,
261 | change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos)
262 |
263 | choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO,
264 | 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM, 'depth': cfg.SEARCH_SPACE.DEPTH}
265 |
266 | model.to(device)
267 |
268 | model_without_ddp = model
269 | if args.distributed:
270 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
271 | model_without_ddp = model.module
272 |
273 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
274 | print('number of params:', n_parameters)
275 |
276 | if args.mixup > 0.:
277 | # smoothing is handled with mixup label transform
278 | criterion = SoftTargetCrossEntropy()
279 | elif args.smoothing:
280 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
281 | else:
282 | criterion = torch.nn.CrossEntropyLoss()
283 |
284 | output_dir = Path(args.output_dir)
285 |
286 | # exist_ok guards against races between distributed ranks
287 | output_dir.mkdir(parents=True, exist_ok=True)
288 |
289 | if args.resume:
290 | print('resume')
291 | if args.resume.startswith('https'):
292 | checkpoint = torch.hub.load_state_dict_from_url(
293 | args.resume, map_location='cpu', check_hash=True)
294 | else:
295 | checkpoint = torch.load(args.resume, map_location='cpu')
296 | model_without_ddp.load_state_dict(checkpoint['model'])
297 | if 'epoch' in checkpoint:
298 | args.start_epoch = checkpoint['epoch'] + 1
299 |
300 | print("Start search candidate")
301 | score_maker = ScoreMaker()
302 | score_maker.get_gradient(model, criterion, data_loader_sub_train, args, choices, device, mixup_fn=mixup_fn)
303 | evolution_searcher = Searcher(args, device, model, model_without_ddp, choices, output_dir, score_maker)
304 | print(evolution_searcher.get_params_range())
305 | interval_candidates = evolution_searcher.search(args.interval_cands_output)
306 | score_maker.drop_gradient()
307 |
308 |
309 | if __name__ == '__main__':
310 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', parents=[get_args_parser()])
311 | args = parser.parse_args()
312 | main(args)
313 |
--------------------------------------------------------------------------------