├── .idea ├── .gitignore ├── PreNAS.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── vcs.xml ├── 01_zero_shot_search.sh ├── 02_one_shot_training.sh ├── 03_evaluation.sh ├── LICENSE.txt ├── README.md ├── candidates_to_choices.py ├── evolution_pre_train.py ├── experiments └── supernet │ ├── base.yaml │ ├── small.yaml │ ├── supernet-B.yaml │ ├── supernet-S.yaml │ ├── supernet-T.yaml │ └── tiny.yaml ├── figure └── overview.svg ├── interval_cands ├── base.json ├── small.json └── tiny.json ├── lib ├── config.py ├── cuda.py ├── datasets.py ├── imagenet_withhold.py ├── samplers.py ├── score_maker.py ├── subImageNet.py └── utils.py ├── model ├── module │ ├── Linear_super.py │ ├── __init__.py │ ├── embedding_super.py │ ├── layernorm_super.py │ ├── multihead_super.py │ ├── qkv_super.py │ └── scaling_super.py ├── supernet_transformer.py └── utils.py ├── requirements.txt ├── supernet_engine.py ├── supernet_train.py └── two_step_search.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /01_zero_shot_search.sh: -------------------------------------------------------------------------------- 1 | ### for tiny search space 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env \ 5 | two_step_search.py \ 6 | --gp \ 7 | --change_qk \ 8 | --relative_position \ 9 | --dist-eval \ 10 | --batch-size 64 \ 11 | --data-free \ 12 | --score-method left_super_taylor6 \ 13 | --block-score-method-for-head balance_taylor6_max_dim \ 14 | --block-score-method-for-mlp balance_taylor6_max_dim \ 15 | --cand-per-interval 1 \ 16 | --param-interval 1.0 \ 17 | --min_param_limits 5 \ 18 | --param_limits 12 \ 19 | --data-path ../datas/imagenet \ 20 | --cfg ./experiments/supernet/supernet-T.yaml \ 21 | --interval-cands-output ./interval_cands/tiny.json 22 | 23 | python candidates_to_choices.py ./interval_cands/tiny.json ./experiments/supernet/tiny.yaml 24 | 25 | ### for small search space 26 | #python -m torch.distributed.launch \ 27 | #--nproc_per_node=8 \ 28 | #--use_env \ 29 | #two_step_search.py \ 30 | #--gp \ 31 | #--change_qk \ 32 | #--relative_position \ 33 | #--dist-eval \ 34 | #--batch-size 64 \ 35 | #--data-free \ 36 | #--score-method left_super_taylor6 \ 37 | #--block-score-method-for-head balance_taylor6_max_dim \ 38 | #--block-score-method-for-mlp balance_taylor6_max_dim \ 39 | 
#--cand-per-interval 1 \ 40 | #--param-interval 5.0 \ 41 | #--min_param_limits 13 \ 42 | #--param_limits 33 \ 43 | #--data-path ../datas/imagenet \ 44 | #--cfg ./experiments/supernet/supernet-S.yaml \ 45 | #--interval-cands-output ./interval_cands/small.json 46 | # 47 | #python candidates_to_choices.py ./interval_cands/small.json ./experiments/supernet/small.yaml 48 | 49 | ### for base search space 50 | #python -m torch.distributed.launch \ 51 | #--nproc_per_node=8 \ 52 | #--use_env two_step_search.py \ 53 | #--gp \ 54 | #--change_qk \ 55 | #--relative_position \ 56 | #--dist-eval \ 57 | #--batch-size 64 \ 58 | #--data-free \ 59 | #--score-method left_super_taylor6 \ 60 | #--block-score-method-for-head balance_taylor6_max_dim \ 61 | #--block-score-method-for-mlp balance_taylor6_max_dim \ 62 | #--cand-per-interval 1 \ 63 | #--param-interval 12.0 \ 64 | #--min_param_limits 30 \ 65 | #--param_limits 70 \ 66 | #--data-path ../datas/imagenet \ 67 | #--cfg ./experiments/supernet/supernet-B.yaml \ 68 | #--interval-cands-output ./interval_cands/base.json 69 | # 70 | #python candidates_to_choices.py ./interval_cands/base.json ./experiments/supernet/base.yaml -------------------------------------------------------------------------------- /02_one_shot_training.sh: -------------------------------------------------------------------------------- 1 | ### train PreNAS_tiny 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env \ 5 | supernet_train.py \ 6 | --gp \ 7 | --change_qk \ 8 | --relative_position \ 9 | --mode super \ 10 | --dist-eval \ 11 | --epochs 500 \ 12 | --warmup-epochs 20 \ 13 | --batch-size 128 \ 14 | --min-lr 1e-7 \ 15 | --group-by-dim \ 16 | --group-by-depth \ 17 | --mixup-mode elem \ 18 | --aa rand-n3-m10-mstd0.5-inc1 \ 19 | --recount 2 \ 20 | --data-path ../datas/imagenet \ 21 | --cfg ./experiments/supernet/tiny.yaml \ 22 | --candfile ./interval_cands/tiny.json \ 23 | --output_dir ./output/tiny 24 | 25 | ### train PreNAS_small 26 | #python -m torch.distributed.launch \ 27 | #--nproc_per_node=8 \ 28 | #--use_env \ 29 | #supernet_train.py \ 30 | #--gp \ 31 | #--change_qk \ 32 | #--relative_position \ 33 | #--mode super \ 34 | #--dist-eval \ 35 | #--epochs 500 \ 36 | #--warmup-epochs 20 \ 37 | #--batch-size 128 \ 38 | #--group-by-dim \ 39 | #--group-by-depth \ 40 | #--mixup-mode elem \ 41 | #--aa v0r-mstd0.5 \ 42 | #--data-path ../datas/imagenet \ 43 | #--cfg ./experiments/supernet/small.yaml \ 44 | #--candfile ./interval_cands/small.json \ 45 | #--output_dir ./output/small 46 | 47 | ### train PreNAS_base 48 | #python -m torch.distributed.launch \ 49 | #--nproc_per_node=8 \ 50 | #--use_env \ 51 | #supernet_train.py \ 52 | #--gp \ 53 | #--change_qk \ 54 | #--relative_position \ 55 | #--mode super \ 56 | #--dist-eval \ 57 | #--epochs 500 \ 58 | #--warmup-epochs 20 \ 59 | #--batch-size 128 \ 60 | #--min-lr 1e-7 \ 61 | #--group-by-dim \ 62 | #--group-by-depth \ 63 | #--mixup-mode elem \ 64 | #--aa rand-n3-m10-mstd0.5-inc1 \ 65 | #--recount 2 \ 66 | #--data-path ../datas/imagenet \ 67 | #--cfg ./experiments/supernet/base.yaml \ 68 | #--candfile ./interval_cands/base.json \ 69 | #--output_dir ./output/base 70 | -------------------------------------------------------------------------------- /03_evaluation.sh: -------------------------------------------------------------------------------- 1 | ### eval PreNAS_tiny 2 | python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --use_env \ 5 | supernet_train.py \ 6 | --gp \ 7 | --change_qk \ 8 | 
--relative_position \ 9 | --mode retrain \ 10 | --dist-eval \ 11 | --batch-size 128 \ 12 | --eval \ 13 | --data-path ../datas/imagenet \ 14 | --cfg ./experiments/supernet/tiny.yaml \ 15 | --candfile ./interval_cands/tiny.json \ 16 | --resume ./output/tiny/checkpoint.pth 17 | 18 | ### eval PreNAS_small 19 | #python -m torch.distributed.launch \ 20 | #--nproc_per_node=8 \ 21 | #--use_env \ 22 | #supernet_train.py \ 23 | #--gp \ 24 | #--change_qk \ 25 | #--relative_position \ 26 | #--mode retrain \ 27 | #--dist-eval \ 28 | #--batch-size 128 \ 29 | #--eval \ 30 | #--data-path ../datas/imagenet \ 31 | #--cfg ./experiments/supernet/small.yaml \ 32 | #--candfile ./interval_cands/small.json \ 33 | #--resume ./output/small/checkpoint.pth 34 | 35 | ### eval PreNAS_base 36 | #python -m torch.distributed.launch \ 37 | #--nproc_per_node=8 \ 38 | #--use_env \ 39 | #supernet_train.py \ 40 | #--gp \ 41 | #--change_qk \ 42 | #--relative_position \ 43 | #--mode retrain \ 44 | #--dist-eval \ 45 | #--batch-size 128 \ 46 | #--eval \ 47 | #--data-path ../datas/imagenet \ 48 | #--cfg ./experiments/supernet/base.yaml \ 49 | #--candfile ./interval_cands/base.json \ 50 | #--resume ./output/base/checkpoint.pth 51 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022-2023 Alibaba Group Holding Limited. 190 | 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PreNAS: Preferred One-Shot Learning Towards Efficient Neural Architecture Search 2 | 3 | PreNAS is a novel learning paradigm that integrates one-shot and zero-shot NAS techniques to enhance search efficiency and training effectiveness. 
4 | This search-free approach outperforms current state-of-the-art one-shot NAS methods for both Vision Transformer and convolutional architectures, 5 | as confirmed by the results reported below; the code is released in this repository. 6 | 7 | >Wang H, Ge C, Chen H, and Sun X. PreNAS: Preferred One-Shot Learning Towards Efficient Neural Architecture Search. ICML 2023. 8 | 9 | Paper link: [arXiv](https://arxiv.org/abs/2304.14636) 10 | 11 | ## Overview 12 |
13 | ![PreNAS overview](figure/overview.svg) 14 |
15 | Previous one-shot NAS methods sample all architectures in the search space during one-shot training of the supernet, so that the supernet can evaluate them in the subsequent evolution search. 16 | Instead, PreNAS first searches for the target architectures via a zero-cost proxy and then applies preferred one-shot training to the supernet. 17 | PreNAS improves the Pareto frontier thanks to the preferred one-shot learning, and it is search-free after training because it directly offers models with the 18 | architectures selected in advance by the zero-cost search. 19 | 20 | ## Environment Setup 21 | 22 | To set up the environment, run the following commands: 23 | ```bash 24 | conda create -n PreNAS python=3.7 25 | conda activate PreNAS 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ## Data Preparation 30 | You need to download the [ImageNet-2012](http://www.image-net.org/) dataset to the folder `../datas/imagenet` (the path used by the run scripts). 31 | 32 | ## Run example 33 | The code was run on 8 x 80GB A100 GPUs. 34 | - Zero-Shot Search 35 | 36 | `bash 01_zero_shot_search.sh` 37 | 38 | - One-Shot Training 39 | 40 | `bash 02_one_shot_training.sh` 41 | 42 | - Evaluation 43 | 44 | `bash 03_evaluation.sh` 45 | 46 | ## Model Zoo 47 | 48 | | Model | TOP-1 (%) | TOP-5 (%) | #Params (M) | FLOPs (G) | Download Link | 49 | | ------------ | ---------- | ------------- | ------------- | --------- | ------------- | 50 | | PreNAS-Ti | 77.1 | 93.4 | 5.9 | 1.4 | [AliCloud](https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/PreNAS/supernet-tiny.pth) | 51 | | PreNAS-S | 81.8 | 95.9 | 22.9 | 5.1 | [AliCloud](https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/PreNAS/supernet-small.pth) | 52 | | PreNAS-B | 82.6 | 96.0 | 54 | 11 | [AliCloud](https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/PreNAS/supernet-base.pth) | 53 | A sketch of evaluating a downloaded checkpoint is given at the end of this README. 54 | ## Bibtex 55 | 56 | If PreNAS is useful for you, please consider citing it. Thank you! :) 57 | ```bibtex 58 | @InProceedings{PreNAS, 59 | title = {PreNAS: Preferred One-Shot Learning Towards Efficient Neural Architecture Search}, 60 | author = {Wang, Haibin and Ge, Ce and Chen, Hesen and Sun, Xiuyu}, 61 | booktitle = {International Conference on Machine Learning (ICML)}, 62 | month = {July}, 63 | year = {2023} 64 | } 65 | ``` 66 | 67 | ## Acknowledgements 68 | 69 | The code is inspired by [AutoFormer](https://github.com/microsoft/Cream/tree/main/AutoFormer). 
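To evaluate a released checkpoint from the Model Zoo without training it yourself, you can download it to the path that `03_evaluation.sh` resumes from. A minimal sketch for PreNAS-Ti (the local directory layout is an assumption; the URL is taken from the Model Zoo table above):

```bash
mkdir -p ./output/tiny
# place the released weights where 03_evaluation.sh expects --resume to point
wget https://idstcv.oss-cn-zhangjiakou.aliyuncs.com/PreNAS/supernet-tiny.pth -O ./output/tiny/checkpoint.pth
bash 03_evaluation.sh
```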
70 | -------------------------------------------------------------------------------- /candidates_to_choices.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import json 3 | from collections import defaultdict 4 | 5 | def candidate_to_choices(candidate_path, topN=float('inf')): 6 | interval_cands = json.load(open(candidate_path)) 7 | 8 | # init 9 | new_embed_dim = [] 10 | new_mlp_ratio = defaultdict(lambda : defaultdict(list)) 11 | new_num_heads = defaultdict(lambda : defaultdict(list)) 12 | new_depth = defaultdict(list) 13 | 14 | for cand_list in interval_cands.values(): 15 | for i in range(min(topN, len(cand_list))): 16 | cur_cand = cand_list[i] 17 | # embed dim 18 | embed_dim = cur_cand['embed_dim'][0] 19 | new_embed_dim.append(embed_dim) if embed_dim not in new_embed_dim else None 20 | # depth 21 | depth = cur_cand['layer_num'] 22 | new_depth[embed_dim].append(depth) if depth not in new_depth[embed_dim] else None 23 | # mlp & heads 24 | for layer_id, (mlp_ratio, num_heads) in enumerate(zip(cur_cand['mlp_ratio'], cur_cand['num_heads'])): 25 | pt_mlp_ratio = new_mlp_ratio[embed_dim][layer_id] 26 | pt_mlp_ratio.append(mlp_ratio) if mlp_ratio not in pt_mlp_ratio else None 27 | pt_num_heads = new_num_heads[embed_dim][layer_id] 28 | pt_num_heads.append(num_heads) if num_heads not in pt_num_heads else None 29 | 30 | return {'embed_dim': sorted(new_embed_dim), 31 | 'mlp_ratio': {dim: [sorted(ratios[layer]) for layer in sorted(ratios)] for dim, ratios in new_mlp_ratio.items()}, 32 | 'num_heads': {dim: [sorted(heads[layer]) for layer in sorted(heads)] for dim, heads in new_num_heads.items()}, 33 | 'depth': {dim: sorted(deps) for dim, deps in new_depth.items()}, 34 | } 35 | 36 | 37 | if __name__ == '__main__': 38 | import os, sys, yaml 39 | 40 | cand_file = os.path.normpath(sys.argv[1]) 41 | conf_file = os.path.normpath(sys.argv[2]) 42 | if os.path.exists(conf_file): 43 | print(f'Target file already exists: {conf_file}') 44 | exit() 45 | 46 | new_choices = candidate_to_choices(cand_file) 47 | #print(new_choices) 48 | cfg = dict() 49 | cfg['SEARCH_SPACE'] = {k.upper(): v for k, v in new_choices.items()} 50 | max_depth = max({dep for deps in new_choices['depth'].values() for dep in deps}) 51 | max_ratio = max(max(ratio_list) for ratios in new_choices['mlp_ratio'].values() for ratio_list in ratios) 52 | max_heads = max(max(heads_list) for heads in new_choices['num_heads'].values() for heads_list in heads) 53 | max_dim = max_heads * 64 54 | assert max_dim >= max(new_choices['embed_dim']) 55 | cfg['SUPERNET'] = {'DEPTH': max_depth, 'MLP_RATIO': max_ratio, 'NUM_HEADS': max_heads, 'EMBED_DIM': max_dim} 56 | 57 | yaml.safe_dump(cfg, open(conf_file, 'w')) 58 | print(f'Saved to: {conf_file}') 59 | -------------------------------------------------------------------------------- /evolution_pre_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | from pathlib import Path 8 | 9 | from lib.datasets import build_dataset 10 | from lib import utils 11 | from supernet_engine import evaluate 12 | from model.supernet_transformer import Vision_TransformerSuper 13 | import argparse 14 | import os 15 | import yaml 16 | from lib.config import cfg, update_config_from_file 17 | from lib.score_maker import ScoreMaker 18 | import math 19 | from itertools import combinations 20 | import json 21 | 22 | 
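# A candidate architecture is encoded as one flat tuple:
#   (depth, mlp_ratio[0..depth-1], num_heads[0..depth-1], embed_dim)
# Example with hypothetical values: (2, 3.5, 4.0, 3, 4, 240) decodes below to
#   depth=2, mlp_ratio=[3.5, 4.0], num_heads=[3, 4], embed_dim=240.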
23 | def decode_cand_tuple(cand_tuple): 24 | depth = cand_tuple[0] 25 | return depth, list(cand_tuple[1:depth+1]), list(cand_tuple[depth + 1: 2 * depth + 1]), cand_tuple[-1] 26 | 27 | 28 | def get_max_min_model(choices): 29 | max_depth = max(choices['depth']) 30 | max_emb = max(choices['embed_dim']) 31 | max_num_head = max(choices['num_heads']) 32 | max_mlp_ratio = max(choices['mlp_ratio']) 33 | min_depth = min(choices['depth']) 34 | min_emb = min(choices['embed_dim']) 35 | min_num_head = min(choices['num_heads']) 36 | min_mlp_ratio = min(choices['mlp_ratio']) 37 | max_model = tuple([max_depth] + [max_mlp_ratio] * max_depth + [max_num_head] * max_depth + [max_emb]) 38 | min_model = tuple([min_depth] + [min_mlp_ratio] * min_depth + [min_num_head] * min_depth + [min_emb]) 39 | return max_model, min_model 40 | 41 | 42 | class Searcher(object): 43 | 44 | def __init__(self, args, device, model, model_without_ddp, choices, output_dir, score_maker): 45 | self.device = device 46 | self.model = model 47 | self.model_without_ddp = model_without_ddp 48 | self.args = args 49 | self.max_epochs = args.max_epochs 50 | self.select_num = args.select_num 51 | self.population_num = args.population_num 52 | self.m_prob = args.m_prob 53 | self.crossover_num = args.crossover_num 54 | self.mutation_num = args.mutation_num 55 | self.parameters_limits = args.param_limits 56 | self.min_parameters_limits = args.min_param_limits 57 | self.output_dir = output_dir 58 | self.s_prob =args.s_prob 59 | self.memory = [] 60 | self.vis_dict = {} 61 | self.keep_top_k = {} 62 | self.epoch = 0 63 | self.checkpoint_path = args.resume 64 | self.candidates = [] 65 | self.top_accuracies = [] 66 | self.cand_params = [] 67 | self.choices = choices 68 | self.choices['num_heads'].sort() 69 | self.choices['mlp_ratio'].sort() 70 | 71 | self.score_maker = score_maker 72 | self.eval_cnt = 0 73 | self.update_num = 0 74 | self.un_update_cnt = 0 75 | 76 | self.all_cands = {} 77 | min_param = self.min_parameters_limits 78 | max_param = min_param + self.args.param_interval 79 | while max_param < self.parameters_limits + 1e-6: 80 | params = (max_param + min_param) / 2 81 | self.all_cands[self.param_to_index(params)] = [] 82 | min_param = max_param 83 | max_param = min_param + self.args.param_interval 84 | 85 | self.cur_min_param = args.min_param_limits 86 | self.cur_max_param = args.param_limits 87 | self.interval_cands = {} 88 | self.max_model, self.min_model = get_max_min_model(choices) 89 | self.search_mode = args.search_mode 90 | self.head_mlp_scores = {} 91 | 92 | def get_params_range(self): 93 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(self.max_model) 94 | sampled_config = {} 95 | sampled_config['layer_num'] = depth 96 | sampled_config['mlp_ratio'] = mlp_ratio 97 | sampled_config['num_heads'] = num_heads 98 | sampled_config['embed_dim'] = [embed_dim] * depth 99 | 100 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config) 101 | max_params = n_parameters / 10. ** 6 102 | 103 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(self.min_model) 104 | sampled_config = {} 105 | sampled_config['layer_num'] = depth 106 | sampled_config['mlp_ratio'] = mlp_ratio 107 | sampled_config['num_heads'] = num_heads 108 | sampled_config['embed_dim'] = [embed_dim] * depth 109 | 110 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config) 111 | min_params = n_parameters / 10. 
** 6 112 | return min_params, max_params 113 | 114 | def select_cands(self, *, key, reverse=True): 115 | for k in self.all_cands.keys(): 116 | t = self.all_cands[k] 117 | t.sort(key=key, reverse=reverse) 118 | self.all_cands[k] = t[:self.args.cand_per_interval] 119 | 120 | def param_to_index(self, param): 121 | if param < self.min_parameters_limits: 122 | return -1 123 | if param >= self.parameters_limits: 124 | return -1 125 | return math.floor((param - self.min_parameters_limits) / self.args.param_interval) 126 | 127 | def index_to_param_interval(self, index): 128 | if index == -1: 129 | return (0, self.min_parameters_limits) 130 | if index == -2: 131 | return (self.parameters_limits, 2*self.parameters_limits) 132 | down = self.min_parameters_limits + index * self.args.param_interval 133 | up = down + self.args.param_interval 134 | return (down, up) 135 | 136 | def stack_random_cand(self, random_func, *, batchsize=10): 137 | while True: 138 | cands = [random_func() for _ in range(batchsize)] 139 | for cand in cands: 140 | if cand not in self.vis_dict: 141 | self.vis_dict[cand] = {} 142 | info = self.vis_dict[cand] 143 | for cand in cands: 144 | yield cand 145 | 146 | 147 | def get_random_cand_without_reallocate(self): 148 | 149 | cand_tuple = list() 150 | dimensions = ['mlp_ratio', 'num_heads'] 151 | depth = random.choice(self.choices['depth']) 152 | cand_tuple.append(depth) 153 | for dimension in dimensions: 154 | idx = list(range(len(self.choices[dimension]))) 155 | random.shuffle(idx) 156 | choice_cnt = {} 157 | left_layers = depth 158 | for i in idx[:-1]: 159 | choice = self.choices[dimension][i] 160 | cnt = random.choice(range(left_layers + 1)) 161 | left_layers = left_layers - cnt 162 | choice_cnt[choice] = cnt 163 | choice = self.choices[dimension][idx[-1]] 164 | choice_cnt[choice] = left_layers 165 | conf = [0] * depth 166 | 167 | for choice in self.choices[dimension][1:][::-1]: 168 | scores = np.random.rand(depth) 169 | mask = np.where(np.array(conf) > 0, -1, 1) 170 | mask_scores = scores * mask 171 | for i in mask_scores.argsort()[::-1][:choice_cnt[choice]]: 172 | conf[i] = choice 173 | for i in range(len(conf)): 174 | if conf[i] == 0: 175 | conf[i] = self.choices[dimension][0] 176 | 177 | cand_tuple.extend(conf) 178 | 179 | cand_tuple.append(random.choice(self.choices['embed_dim'])) 180 | return tuple(cand_tuple) 181 | 182 | def get_random_cand(self): 183 | 184 | cand_tuple = list() 185 | dimensions = ['mlp_ratio', 'num_heads'] 186 | score_names = ['mlp_scores', 'head_scores'] 187 | depth = random.choice(self.choices['depth']) 188 | cand_tuple.append(depth) 189 | emb_dim = random.choice(self.choices['embed_dim']) 190 | max_dim = max(self.choices['embed_dim']) 191 | for (dimension, score_name) in zip(dimensions, score_names): 192 | idx = list(range(len(self.choices[dimension]))) 193 | random.shuffle(idx) 194 | choice_cnt = {} 195 | left_layers = depth 196 | for i in idx[:-1]: 197 | choice = self.choices[dimension][i] 198 | cnt = random.choice(range(left_layers + 1)) 199 | left_layers = left_layers - cnt 200 | choice_cnt[choice] = cnt 201 | choice = self.choices[dimension][idx[-1]] 202 | choice_cnt[choice] = left_layers 203 | choice_cnt_list = [choice_cnt[choice] for choice in self.choices[dimension]] 204 | method = None 205 | if dimension == 'mlp_ratio': 206 | method = self.args.block_score_method_for_mlp 207 | else: 208 | method = self.args.block_score_method_for_head 209 | cand_tuple.extend(self.reallocate(depth, 210 | emb_dim, 211 | dimension, 212 | 
self.head_mlp_scores[score_name], 213 | choice_cnt_list, 214 | method)) 215 | 216 | cand_tuple.append(emb_dim) 217 | return tuple(cand_tuple) 218 | 219 | def get_random(self, num): 220 | print('random select ........') 221 | if self.args.search_mode == 'iteration' or self.args.reallocate: 222 | cand_iter = self.stack_random_cand(self.get_random_cand) 223 | else: 224 | cand_iter = self.stack_random_cand(self.get_random_cand_without_reallocate) 225 | while len(self.candidates) < num: 226 | cand = next(cand_iter) 227 | if not self.is_legal(cand): 228 | continue 229 | self.candidates.append(cand) 230 | print('random {}/{}'.format(len(self.candidates), num)) 231 | print('random_num = {}'.format(len(self.candidates))) 232 | 233 | def is_legal(self, cand): 234 | assert isinstance(cand, tuple) 235 | 236 | if cand not in self.vis_dict: 237 | self.vis_dict[cand] = {} 238 | info = self.vis_dict[cand] 239 | if 'visited' in info: 240 | return False 241 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand) 242 | sampled_config = {} 243 | sampled_config['layer_num'] = depth 244 | sampled_config['mlp_ratio'] = mlp_ratio 245 | sampled_config['num_heads'] = num_heads 246 | sampled_config['embed_dim'] = [embed_dim]*depth 247 | 248 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config) 249 | info['params'] = n_parameters / 10.**6 250 | 251 | if info['params'] > self.cur_max_param: 252 | print('parameters limit exceed {}'.format(self.cur_max_param)) 253 | return False 254 | 255 | if info['params'] < self.cur_min_param: 256 | print('under minimum parameters limit {}'.format(self.cur_min_param)) 257 | return False 258 | 259 | info['visited'] = True 260 | 261 | return True 262 | 263 | def conf_to_cnt_list(self, conf, part): 264 | cnt_list = [0]*len(self.choices[part]) 265 | for choice in conf: 266 | cnt_list[self.choices[part].index(choice)] += 1 267 | return cnt_list 268 | 269 | def reallocate(self, depth, embed_dim, part, scores, choice_cnt, method): 270 | 271 | if method == 'deeper_is_better': 272 | conf = [] 273 | for choice, cnt in zip(self.choices[part], choice_cnt): 274 | conf = conf + ([choice] * cnt) 275 | return conf 276 | 277 | if 'max_dim' in method: 278 | embed_dim = max(self.choices['embed_dim']) 279 | 280 | conf = [0] * depth 281 | for choice, cnt in zip(self.choices[part][1:][::-1], choice_cnt[1:][::-1]): 282 | cur_scores = np.array(scores[(f"{embed_dim},{choice}")][:depth]) 283 | mask = np.where(np.array(conf) > 0, -1, 1) 284 | mask_scores = cur_scores * mask 285 | for i in mask_scores.argsort()[::-1][:cnt]: 286 | conf[i] = choice 287 | for i in range(len(conf)): 288 | if conf[i] == 0: 289 | conf[i] = self.choices[part][0] 290 | return conf 291 | 292 | def get_score(self): 293 | for cand in self.candidates: 294 | info = self.vis_dict[cand] 295 | if self.args.score_method == 'params': 296 | info['score'] = info['params'] 297 | else: 298 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand) 299 | sampled_config = {} 300 | sampled_config['layer_num'] = depth 301 | sampled_config['mlp_ratio'] = mlp_ratio 302 | sampled_config['num_heads'] = num_heads 303 | sampled_config['embed_dim'] = [embed_dim] * depth 304 | score = self.score_maker.get_score(self.model, self.args.score_method, config=sampled_config) 305 | info['score'] = score 306 | 307 | def update_top_k(self, candidates, *, k, key, reverse=True, get_update_num=False): 308 | assert k in self.keep_top_k 309 | print('select ......') 310 | t = self.keep_top_k[k] 311 | t += candidates 312 | 
t.sort(key=key, reverse=reverse) 313 | self.keep_top_k[k] = t[:k] 314 | if get_update_num: 315 | self.update_num = 0 316 | for cand in self.keep_top_k[k]: 317 | if cand in candidates: 318 | self.update_num += 1 319 | print('update {} models in top {}.'.format(self.update_num, k)) 320 | if self.update_num == 0: 321 | self.un_update_cnt += 1 322 | 323 | def get_mutation(self, k, mutation_num, m_prob, s_prob): 324 | assert k in self.keep_top_k 325 | print('mutation ......') 326 | res = [] 327 | iter = 0 328 | max_iters = mutation_num * 10 329 | 330 | def random_func(): 331 | cand = list(random.choice(self.keep_top_k[k])) 332 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand) 333 | random_s = random.random() 334 | 335 | # depth 336 | if random_s < s_prob: 337 | new_depth = random.choice(self.choices['depth']) 338 | 339 | if new_depth > depth: 340 | mlp_ratio = mlp_ratio + [random.choice(self.choices['mlp_ratio']) for _ in range(new_depth - depth)] 341 | num_heads = num_heads + [random.choice(self.choices['num_heads']) for _ in range(new_depth - depth)] 342 | else: 343 | mlp_ratio = mlp_ratio[:new_depth] 344 | num_heads = num_heads[:new_depth] 345 | 346 | depth = new_depth 347 | # mlp_ratio 348 | for i in range(depth): 349 | random_s = random.random() 350 | if random_s < m_prob: 351 | mlp_ratio[i] = random.choice(self.choices['mlp_ratio']) 352 | 353 | # num_heads 354 | 355 | for i in range(depth): 356 | random_s = random.random() 357 | if random_s < m_prob: 358 | num_heads[i] = random.choice(self.choices['num_heads']) 359 | 360 | # embed_dim 361 | random_s = random.random() 362 | if random_s < s_prob: 363 | embed_dim = random.choice(self.choices['embed_dim']) 364 | 365 | mlp_cnt = self.conf_to_cnt_list(mlp_ratio, 'mlp_ratio') 366 | head_cnt = self.conf_to_cnt_list(num_heads, 'num_heads') 367 | mlp_ratio = self.reallocate(depth, 368 | embed_dim, 369 | 'mlp_ratio', 370 | self.head_mlp_scores['mlp_scores'], 371 | mlp_cnt, 372 | self.args.block_score_method_for_mlp) 373 | num_heads = self.reallocate(depth, 374 | embed_dim, 375 | 'num_heads', 376 | self.head_mlp_scores['head_scores'], 377 | head_cnt, 378 | self.args.block_score_method_for_head) 379 | 380 | result_cand = [depth] + mlp_ratio + num_heads + [embed_dim] 381 | 382 | return tuple(result_cand) 383 | 384 | cand_iter = self.stack_random_cand(random_func) 385 | while len(res) < mutation_num and max_iters > 0: 386 | max_iters -= 1 387 | cand = next(cand_iter) 388 | if not self.is_legal(cand): 389 | continue 390 | res.append(cand) 391 | print('mutation {}/{}'.format(len(res), mutation_num)) 392 | 393 | print('mutation_num = {}'.format(len(res))) 394 | return res 395 | 396 | def get_crossover(self, k, crossover_num): 397 | assert k in self.keep_top_k 398 | print('crossover ......') 399 | res = [] 400 | iter = 0 401 | max_iters = 10 * crossover_num 402 | 403 | def random_func(): 404 | 405 | p1 = random.choice(self.keep_top_k[k]) 406 | p2 = random.choice(self.keep_top_k[k]) 407 | max_iters_tmp = 50 408 | while len(p1) != len(p2) and max_iters_tmp > 0: 409 | max_iters_tmp -= 1 410 | p1 = random.choice(self.keep_top_k[k]) 411 | p2 = random.choice(self.keep_top_k[k]) 412 | cand = tuple(random.choice([i, j]) for i, j in zip(p1, p2)) 413 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand) 414 | mlp_cnt = self.conf_to_cnt_list(mlp_ratio, 'mlp_ratio') 415 | head_cnt = self.conf_to_cnt_list(num_heads, 'num_heads') 416 | mlp_ratio = self.reallocate(depth, 417 | embed_dim, 418 | 'mlp_ratio', 419 | 
self.head_mlp_scores['mlp_scores'], 420 | mlp_cnt, 421 | self.args.block_score_method_for_mlp) 422 | num_heads = self.reallocate(depth, 423 | embed_dim, 424 | 'num_heads', 425 | self.head_mlp_scores['head_scores'], 426 | head_cnt, 427 | self.args.block_score_method_for_head) 428 | result_cand = [depth] + mlp_ratio + num_heads + [embed_dim] 429 | return tuple(result_cand) 430 | 431 | cand_iter = self.stack_random_cand(random_func) 432 | while len(res) < crossover_num and max_iters > 0: 433 | max_iters -= 1 434 | cand = next(cand_iter) 435 | if not self.is_legal(cand): 436 | continue 437 | res.append(cand) 438 | print('crossover {}/{}'.format(len(res), crossover_num)) 439 | 440 | print('crossover_num = {}'.format(len(res))) 441 | return res 442 | 443 | def search(self, out_file_name=None): 444 | 445 | print('searching...') 446 | if not self.args.block_score_method_for_mlp == 'deeper_is_better' or not self.args.block_score_method_for_head == 'deeper_is_better': 447 | self.head_mlp_scores = self.score_maker.get_block_scores(self.model, self.args, self.choices) 448 | 449 | # random search 450 | if self.args.search_mode == 'random': 451 | self.cur_min_param = self.min_parameters_limits 452 | self.cur_max_param = self.cur_min_param + self.args.param_interval 453 | 454 | while self.cur_max_param < self.parameters_limits + 1e-6: 455 | self.candidates = [] 456 | self.keep_top_k = {100: []} 457 | self.get_random(self.population_num) 458 | self.get_score() 459 | self.update_top_k( 460 | self.candidates, k=100, key=lambda x: self.vis_dict[x]['score']) 461 | for i, cand in enumerate(self.keep_top_k[100]): 462 | print('No.{} {} score = {}, params = {}'.format( 463 | i + 1, cand, self.vis_dict[cand]['score'], self.vis_dict[cand]['params'])) 464 | self.interval_cands[(self.cur_min_param, self.cur_max_param)] = self.keep_top_k[100][:self.args.cand_per_interval] 465 | self.cur_min_param = self.cur_max_param 466 | self.cur_max_param = self.cur_min_param + self.args.param_interval 467 | # evolution search 468 | elif self.args.search_mode == 'evolution': 469 | self.cur_min_param = self.min_parameters_limits 470 | self.cur_max_param = self.cur_min_param + self.args.param_interval 471 | 472 | while self.cur_max_param < self.parameters_limits + 1e-6: 473 | self.update_num = 0 474 | self.un_update_cnt = 0 475 | self.epoch = 0 476 | self.candidates = [] 477 | self.keep_top_k = {self.select_num: [], 100: []} 478 | self.get_random(self.population_num) 479 | while self.epoch < self.max_epochs: 480 | print('epoch = {} for param {} to param {}'.format(self.epoch, self.cur_min_param, self.cur_max_param)) 481 | 482 | if self.un_update_cnt == 2: 483 | self.epoch += 1 484 | continue 485 | 486 | self.get_score() 487 | self.update_top_k( 488 | self.candidates, k=self.select_num, key=lambda x: self.vis_dict[x]['score'], get_update_num=True) 489 | self.update_top_k( 490 | self.candidates, k=100, key=lambda x: self.vis_dict[x]['score']) 491 | 492 | print('epoch = {} for param {} to param {} : top {} result'.format( 493 | self.epoch, self.cur_min_param, self.cur_max_param, len(self.keep_top_k[100]))) 494 | for i, cand in enumerate(self.keep_top_k[100]): 495 | print('No.{} {} score = {}, params = {}'.format( 496 | i + 1, cand, self.vis_dict[cand]['score'], self.vis_dict[cand]['params'])) 497 | 498 | self.epoch += 1 499 | if self.epoch >= self.max_epochs: 500 | break 501 | 502 | # check 503 | mutation = self.get_mutation( 504 | self.select_num, self.mutation_num, self.m_prob, self.s_prob) 505 | crossover = 
self.get_crossover(self.select_num, self.crossover_num) 506 | 507 | self.candidates = mutation + crossover 508 | 509 | self.get_random(self.population_num) 510 | 511 | self.interval_cands[(self.cur_min_param, self.cur_max_param)] = self.keep_top_k[100][:self.args.cand_per_interval] 512 | self.cur_min_param = self.cur_max_param 513 | self.cur_max_param = self.cur_min_param + self.args.param_interval 514 | # force search 515 | else: 516 | max_dim = max(self.choices['embed_dim']) 517 | iter_cnt = 0 518 | for embed_dim in self.choices['embed_dim']: 519 | for depth in self.choices['depth']: 520 | depth_ids = list(range(depth+1)) 521 | num_head_choice = len(self.choices['num_heads']) 522 | num_mlp_choice = len(self.choices['mlp_ratio']) 523 | mlp_confs = [] 524 | head_confs = [] 525 | 526 | for mlp_dist in combinations(depth_ids, num_mlp_choice - 1): 527 | mlp_dist = [0] + list(mlp_dist) + [depth] 528 | mlp_cnt = [mlp_dist[i+1] - mlp_dist[i] for i in range(len(mlp_dist)-1)] 529 | mlp_confs.append(self.reallocate(depth, 530 | embed_dim, 531 | 'mlp_ratio', 532 | self.head_mlp_scores['mlp_scores'], 533 | mlp_cnt, 534 | self.args.block_score_method_for_mlp)) 535 | 536 | for head_dist in combinations(depth_ids, num_head_choice - 1): 537 | head_dist = [0] + list(head_dist) + [depth] 538 | head_cnt = [head_dist[i+1] - head_dist[i] for i in range(len(head_dist)-1)] 539 | head_confs.append(self.reallocate(depth, 540 | embed_dim, 541 | 'num_heads', 542 | self.head_mlp_scores['head_scores'], 543 | head_cnt, 544 | self.args.block_score_method_for_head)) 545 | 546 | for mlp_conf in mlp_confs: 547 | iter_cnt += 1 548 | for head_conf in head_confs: 549 | cand = tuple([depth] + mlp_conf + head_conf + [embed_dim]) 550 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand) 551 | sampled_config = {} 552 | sampled_config['layer_num'] = depth 553 | sampled_config['mlp_ratio'] = mlp_ratio 554 | sampled_config['num_heads'] = num_heads 555 | sampled_config['embed_dim'] = [embed_dim] * depth 556 | n_parameters = self.model_without_ddp.get_sampled_params_numel(sampled_config) 557 | params = n_parameters / 10. 
** 6 558 | index = self.param_to_index(params) 559 | 560 | if self.args.score_method == 'params': 561 | score = params 562 | else: 563 | score = self.score_maker.get_score(self.model, self.args.score_method, config=sampled_config) 564 | 565 | info = {'cand': cand, 'score': score, 'params': params} 566 | self.vis_dict[cand] = info 567 | if index in self.all_cands.keys(): 568 | self.all_cands[index].append(info) 569 | 570 | self.select_cands(key=lambda x: x['score']) 571 | 572 | for index in self.all_cands.keys(): 573 | k = self.index_to_param_interval(index) 574 | self.interval_cands[k] = [item['cand'] for item in self.all_cands[index]] 575 | 576 | if out_file_name is None: 577 | out_file_name = f'out/interval_cands_{self.args.super_model_size}_{self.args.score_method}_{self.args.block_score_method_for_mlp}_for_mlp_{self.args.block_score_method_for_head}_for_head' 578 | out_file_name += f'_i{self.args.param_interval}_top_{self.args.cand_per_interval}.pt' 579 | torch.save(self.interval_cands, out_file_name) 580 | else: 581 | json_dict = {} 582 | for interval in self.interval_cands.keys(): 583 | cand_list = [] 584 | for cand in self.interval_cands[interval]: 585 | depth, mlp_ratio, num_heads, embed_dim = decode_cand_tuple(cand) 586 | info = { 587 | 'layer_num': depth, 588 | 'mlp_ratio': mlp_ratio, 589 | 'num_heads': num_heads, 590 | 'embed_dim': [embed_dim]*depth, 591 | 'num_params': float(self.vis_dict[cand]['params']), 592 | 'score': float(self.vis_dict[cand]['score']) 593 | } 594 | cand_list.append(info) 595 | if len(cand_list) > 0: 596 | json_dict[str(interval[1])] = cand_list 597 | print("selected candidates:") 598 | print(json_dict) 599 | with open(out_file_name, "w") as fp: 600 | json.dump(json_dict, fp, indent=2) 601 | fp.close() 602 | 603 | 604 | 605 | return self.interval_cands 606 | 607 | -------------------------------------------------------------------------------- /experiments/supernet/base.yaml: -------------------------------------------------------------------------------- 1 | SEARCH_SPACE: 2 | DEPTH: 3 | 528: 4 | - 14 5 | 576: 6 | - 14 7 | 624: 8 | - 14 9 | EMBED_DIM: 10 | - 528 11 | - 576 12 | - 624 13 | MLP_RATIO: 14 | 528: 15 | - - 3.0 16 | - - 3.0 17 | - - 3.0 18 | - - 3.0 19 | - - 3.0 20 | - - 3.0 21 | - - 3.0 22 | - - 3.0 23 | - - 3.0 24 | - - 3.0 25 | - - 3.0 26 | - - 3.0 27 | - - 3.5 28 | - - 3.0 29 | 576: 30 | - - 3.5 31 | - - 4.0 32 | - - 3.5 33 | - - 3.5 34 | - - 3.5 35 | - - 4.0 36 | - - 3.5 37 | - - 3.5 38 | - - 3.5 39 | - - 3.5 40 | - - 4.0 41 | - - 3.0 42 | - - 3.5 43 | - - 4.0 44 | 624: 45 | - - 4.0 46 | - - 4.0 47 | - - 3.5 48 | - - 4.0 49 | - - 4.0 50 | - - 4.0 51 | - - 4.0 52 | - - 4.0 53 | - - 4.0 54 | - - 4.0 55 | - - 4.0 56 | - - 4.0 57 | - - 4.0 58 | - - 4.0 59 | NUM_HEADS: 60 | 528: 61 | - - 9 62 | - - 9 63 | - - 9 64 | - - 9 65 | - - 9 66 | - - 9 67 | - - 9 68 | - - 9 69 | - - 9 70 | - - 9 71 | - - 9 72 | - - 9 73 | - - 9 74 | - - 9 75 | 576: 76 | - - 9 77 | - - 9 78 | - - 10 79 | - - 9 80 | - - 9 81 | - - 10 82 | - - 9 83 | - - 9 84 | - - 9 85 | - - 9 86 | - - 9 87 | - - 9 88 | - - 10 89 | - - 9 90 | 624: 91 | - - 9 92 | - - 10 93 | - - 10 94 | - - 9 95 | - - 9 96 | - - 10 97 | - - 10 98 | - - 9 99 | - - 10 100 | - - 9 101 | - - 9 102 | - - 9 103 | - - 10 104 | - - 10 105 | SUPERNET: 106 | DEPTH: 14 107 | EMBED_DIM: 640 108 | MLP_RATIO: 4.0 109 | NUM_HEADS: 10 110 | -------------------------------------------------------------------------------- /experiments/supernet/small.yaml: 
-------------------------------------------------------------------------------- 1 | SEARCH_SPACE: 2 | DEPTH: 3 | 320: 4 | - 13 5 | 384: 6 | - 13 7 | 448: 8 | - 13 9 | - 14 10 | EMBED_DIM: 11 | - 320 12 | - 384 13 | - 448 14 | MLP_RATIO: 15 | 320: 16 | - - 4.0 17 | - - 4.0 18 | - - 3.5 19 | - - 4.0 20 | - - 3.5 21 | - - 4.0 22 | - - 4.0 23 | - - 3.5 24 | - - 4.0 25 | - - 4.0 26 | - - 3.5 27 | - - 4.0 28 | - - 4.0 29 | 384: 30 | - - 4.0 31 | - - 4.0 32 | - - 3.5 33 | - - 4.0 34 | - - 3.5 35 | - - 4.0 36 | - - 3.5 37 | - - 3.5 38 | - - 4.0 39 | - - 4.0 40 | - - 3.0 41 | - - 4.0 42 | - - 3.5 43 | 448: 44 | - - 4.0 45 | - - 4.0 46 | - - 3.0 47 | - 3.5 48 | - - 4.0 49 | - - 3.0 50 | - 3.5 51 | - - 3.5 52 | - 4.0 53 | - - 3.0 54 | - 4.0 55 | - - 3.0 56 | - 3.5 57 | - - 4.0 58 | - - 3.0 59 | - 4.0 60 | - - 3.0 61 | - 3.5 62 | - - 4.0 63 | - - 3.0 64 | - 4.0 65 | - - 3.5 66 | NUM_HEADS: 67 | 320: 68 | - - 7 69 | - - 7 70 | - - 7 71 | - - 7 72 | - - 7 73 | - - 7 74 | - - 7 75 | - - 7 76 | - - 6 77 | - - 7 78 | - - 5 79 | - - 6 80 | - - 5 81 | 384: 82 | - - 7 83 | - - 7 84 | - - 7 85 | - - 7 86 | - - 5 87 | - - 7 88 | - - 7 89 | - - 5 90 | - - 6 91 | - - 5 92 | - - 5 93 | - - 6 94 | - - 5 95 | 448: 96 | - - 7 97 | - - 7 98 | - - 7 99 | - - 7 100 | - - 5 101 | - 7 102 | - - 7 103 | - - 7 104 | - - 5 105 | - 7 106 | - - 5 107 | - 6 108 | - - 5 109 | - 7 110 | - - 5 111 | - - 6 112 | - - 5 113 | - - 7 114 | SUPERNET: 115 | DEPTH: 14 116 | EMBED_DIM: 448 117 | MLP_RATIO: 4.0 118 | NUM_HEADS: 7 119 | -------------------------------------------------------------------------------- /experiments/supernet/supernet-B.yaml: -------------------------------------------------------------------------------- 1 | SUPERNET: 2 | MLP_RATIO: 4.0 3 | NUM_HEADS: 10 4 | EMBED_DIM: 640 5 | DEPTH: 16 6 | SEARCH_SPACE: 7 | MLP_RATIO: 8 | - 3.0 9 | - 3.5 10 | - 4.0 11 | NUM_HEADS: 12 | - 9 13 | - 10 14 | DEPTH: 15 | - 14 16 | - 15 17 | - 16 18 | EMBED_DIM: 19 | - 528 20 | - 576 21 | - 624 22 | -------------------------------------------------------------------------------- /experiments/supernet/supernet-S.yaml: -------------------------------------------------------------------------------- 1 | SUPERNET: 2 | MLP_RATIO: 4.0 3 | NUM_HEADS: 7 4 | EMBED_DIM: 448 5 | DEPTH: 14 6 | SEARCH_SPACE: 7 | MLP_RATIO: 8 | - 3.0 9 | - 3.5 10 | - 4.0 11 | NUM_HEADS: 12 | - 5 13 | - 6 14 | - 7 15 | DEPTH: 16 | - 13 17 | - 14 18 | EMBED_DIM: 19 | - 320 20 | - 384 21 | - 448 22 | -------------------------------------------------------------------------------- /experiments/supernet/supernet-T.yaml: -------------------------------------------------------------------------------- 1 | SUPERNET: 2 | MLP_RATIO: 4.0 3 | NUM_HEADS: 4 4 | EMBED_DIM: 256 5 | DEPTH: 14 6 | SEARCH_SPACE: 7 | MLP_RATIO: 8 | - 3.5 9 | - 4 10 | NUM_HEADS: 11 | - 3 12 | - 4 13 | DEPTH: 14 | - 12 15 | - 13 16 | - 14 17 | EMBED_DIM: 18 | - 192 19 | - 216 20 | - 240 21 | -------------------------------------------------------------------------------- /experiments/supernet/tiny.yaml: -------------------------------------------------------------------------------- 1 | SEARCH_SPACE: 2 | DEPTH: 3 | 192: 4 | - 12 5 | 216: 6 | - 12 7 | 240: 8 | - 12 9 | - 14 10 | EMBED_DIM: 11 | - 192 12 | - 216 13 | - 240 14 | MLP_RATIO: 15 | 192: 16 | - - 4 17 | - - 4 18 | - - 4 19 | - - 4 20 | - - 4 21 | - - 3.5 22 | - - 3.5 23 | - - 4 24 | - - 4 25 | - - 3.5 26 | - - 4 27 | - - 4 28 | 216: 29 | - - 4 30 | - - 3.5 31 | - - 4 32 | - - 3.5 33 | - - 3.5 34 | - - 3.5 35 | - - 3.5 36 | - - 3.5 37 | - - 
3.5 38 | - - 3.5 39 | - - 3.5 40 | - - 3.5 41 | 240: 42 | - - 4 43 | - - 3.5 44 | - 4 45 | - - 4 46 | - - 4 47 | - - 3.5 48 | - 4 49 | - - 3.5 50 | - 4 51 | - - 3.5 52 | - 4 53 | - - 4 54 | - - 4 55 | - - 3.5 56 | - 4 57 | - - 4 58 | - - 3.5 59 | - 4 60 | - - 4 61 | - - 4 62 | NUM_HEADS: 63 | 192: 64 | - - 4 65 | - - 4 66 | - - 4 67 | - - 4 68 | - - 3 69 | - - 3 70 | - - 3 71 | - - 3 72 | - - 3 73 | - - 4 74 | - - 3 75 | - - 4 76 | 216: 77 | - - 4 78 | - - 4 79 | - - 4 80 | - - 4 81 | - - 4 82 | - - 3 83 | - - 3 84 | - - 3 85 | - - 4 86 | - - 4 87 | - - 3 88 | - - 4 89 | 240: 90 | - - 3 91 | - 4 92 | - - 3 93 | - 4 94 | - - 3 95 | - 4 96 | - - 3 97 | - 4 98 | - - 3 99 | - 4 100 | - - 3 101 | - 4 102 | - - 3 103 | - 4 104 | - - 3 105 | - 4 106 | - - 3 107 | - 4 108 | - - 3 109 | - 4 110 | - - 3 111 | - 4 112 | - - 3 113 | - 4 114 | - - 3 115 | - 4 116 | - - 4 117 | SUPERNET: 118 | DEPTH: 14 119 | EMBED_DIM: 256 120 | MLP_RATIO: 4 121 | NUM_HEADS: 4 122 | -------------------------------------------------------------------------------- /interval_cands/base.json: -------------------------------------------------------------------------------- 1 | {"42.0": [{"layer_num": 14, "mlp_ratio": [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.5, 3.0], "num_heads": [9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9], "embed_dim": [528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528, 528], "num_params": 41.966944, "score": 5461899.5}], "54.0": [{"layer_num": 14, "mlp_ratio": [3.5, 4.0, 3.5, 3.5, 3.5, 4.0, 3.5, 3.5, 3.5, 3.5, 4.0, 3.0, 3.5, 4.0], "num_heads": [9, 9, 10, 9, 9, 10, 9, 9, 9, 9, 9, 9, 10, 9], "embed_dim": [576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576, 576], "num_params": 53.876104, "score": 6887513.0}], "66.0": [{"layer_num": 14, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0], "num_heads": [9, 10, 10, 9, 9, 10, 10, 9, 10, 9, 9, 9, 10, 10], "embed_dim": [624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624, 624], "num_params": 65.916448, "score": 8251315.0}]} -------------------------------------------------------------------------------- /interval_cands/small.json: -------------------------------------------------------------------------------- 1 | {"18.0": [{"layer_num": 13, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0], "num_heads": [7, 7, 7, 7, 7, 7, 7, 7, 6, 7, 5, 6, 5], "embed_dim": [320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320, 320], "num_params": 17.9914, "score": 2764111.5}], "23.0": [{"layer_num": 13, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 4.0, 4.0, 3.0, 4.0, 3.5], "num_heads": [7, 7, 7, 7, 5, 7, 7, 5, 6, 5, 5, 6, 5], "embed_dim": [384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384, 384], "num_params": 22.989928, "score": 3528041.5}], "28.0": [{"layer_num": 13, "mlp_ratio": [4.0, 4.0, 3.0, 4.0, 3.0, 3.5, 3.0, 3.0, 4.0, 3.0, 3.0, 4.0, 3.0], "num_heads": [7, 7, 7, 7, 5, 7, 7, 5, 6, 5, 5, 6, 5], "embed_dim": [448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448], "num_params": 27.976008, "score": 4264070.0}], "33.0": [{"layer_num": 14, "mlp_ratio": [4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0, 3.5, 4.0, 4.0, 3.5], "num_heads": [7, 7, 7, 7, 7, 7, 7, 7, 5, 7, 5, 6, 5, 7], "embed_dim": [448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448, 448], "num_params": 32.98164, "score": 4739829.0}]} -------------------------------------------------------------------------------- /interval_cands/tiny.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "6.0": [ 3 | { 4 | "layer_num": 12, 5 | "mlp_ratio": [ 6 | 4, 7 | 4, 8 | 4, 9 | 4, 10 | 4, 11 | 3.5, 12 | 3.5, 13 | 4, 14 | 4, 15 | 3.5, 16 | 4, 17 | 4 18 | ], 19 | "num_heads": [ 20 | 4, 21 | 4, 22 | 4, 23 | 4, 24 | 3, 25 | 3, 26 | 3, 27 | 3, 28 | 3, 29 | 4, 30 | 3, 31 | 4 32 | ], 33 | "embed_dim": [ 34 | 192, 35 | 192, 36 | 192, 37 | 192, 38 | 192, 39 | 192, 40 | 192, 41 | 192, 42 | 192, 43 | 192, 44 | 192, 45 | 192 46 | ], 47 | "num_params": 5.99476, 48 | "score": 1010076.75 49 | } 50 | ], 51 | "7.0": [ 52 | { 53 | "layer_num": 12, 54 | "mlp_ratio": [ 55 | 4, 56 | 3.5, 57 | 4, 58 | 3.5, 59 | 3.5, 60 | 3.5, 61 | 3.5, 62 | 3.5, 63 | 3.5, 64 | 3.5, 65 | 3.5, 66 | 3.5 67 | ], 68 | "num_heads": [ 69 | 4, 70 | 4, 71 | 4, 72 | 4, 73 | 4, 74 | 3, 75 | 3, 76 | 3, 77 | 4, 78 | 4, 79 | 3, 80 | 4 81 | ], 82 | "embed_dim": [ 83 | 216, 84 | 216, 85 | 216, 86 | 216, 87 | 216, 88 | 216, 89 | 216, 90 | 216, 91 | 216, 92 | 216, 93 | 216, 94 | 216 95 | ], 96 | "num_params": 6.997192, 97 | "score": 1156532.875 98 | } 99 | ], 100 | "8.0": [ 101 | { 102 | "layer_num": 12, 103 | "mlp_ratio": [ 104 | 4, 105 | 3.5, 106 | 4, 107 | 4, 108 | 3.5, 109 | 3.5, 110 | 3.5, 111 | 4, 112 | 4, 113 | 3.5, 114 | 4, 115 | 3.5 116 | ], 117 | "num_heads": [ 118 | 3, 119 | 3, 120 | 3, 121 | 3, 122 | 3, 123 | 3, 124 | 3, 125 | 3, 126 | 3, 127 | 3, 128 | 3, 129 | 3 130 | ], 131 | "embed_dim": [ 132 | 240, 133 | 240, 134 | 240, 135 | 240, 136 | 240, 137 | 240, 138 | 240, 139 | 240, 140 | 240, 141 | 240, 142 | 240, 143 | 240 144 | ], 145 | "num_params": 7.996552, 146 | "score": 1308445.75 147 | } 148 | ], 149 | "9.0": [ 150 | { 151 | "layer_num": 12, 152 | "mlp_ratio": [ 153 | 4, 154 | 4, 155 | 4, 156 | 4, 157 | 4, 158 | 3.5, 159 | 4, 160 | 4, 161 | 4, 162 | 3.5, 163 | 4, 164 | 4 165 | ], 166 | "num_heads": [ 167 | 4, 168 | 4, 169 | 4, 170 | 4, 171 | 4, 172 | 4, 173 | 4, 174 | 4, 175 | 4, 176 | 4, 177 | 4, 178 | 4 179 | ], 180 | "embed_dim": [ 181 | 240, 182 | 240, 183 | 240, 184 | 240, 185 | 240, 186 | 240, 187 | 240, 188 | 240, 189 | 240, 190 | 240, 191 | 240, 192 | 240 193 | ], 194 | "num_params": 8.967016, 195 | "score": 1426972.25 196 | } 197 | ], 198 | "10.0": [ 199 | { 200 | "layer_num": 14, 201 | "mlp_ratio": [ 202 | 4, 203 | 4, 204 | 4, 205 | 4, 206 | 4, 207 | 3.5, 208 | 3.5, 209 | 4, 210 | 4, 211 | 3.5, 212 | 4, 213 | 3.5, 214 | 4, 215 | 4 216 | ], 217 | "num_heads": [ 218 | 4, 219 | 4, 220 | 4, 221 | 4, 222 | 4, 223 | 3, 224 | 3, 225 | 3, 226 | 4, 227 | 4, 228 | 3, 229 | 4, 230 | 3, 231 | 4 232 | ], 233 | "embed_dim": [ 234 | 240, 235 | 240, 236 | 240, 237 | 240, 238 | 240, 239 | 240, 240 | 240, 241 | 240, 242 | 240, 243 | 240, 244 | 240, 245 | 240, 246 | 240, 247 | 240 248 | ], 249 | "num_params": 9.978232, 250 | "score": 1525250.25 251 | } 252 | ], 253 | "11.0": [ 254 | { 255 | "layer_num": 14, 256 | "mlp_ratio": [ 257 | 4, 258 | 4, 259 | 4, 260 | 4, 261 | 4, 262 | 4, 263 | 4, 264 | 4, 265 | 4, 266 | 4, 267 | 4, 268 | 4, 269 | 4, 270 | 4 271 | ], 272 | "num_heads": [ 273 | 4, 274 | 4, 275 | 4, 276 | 4, 277 | 4, 278 | 4, 279 | 4, 280 | 4, 281 | 4, 282 | 4, 283 | 4, 284 | 4, 285 | 4, 286 | 4 287 | ], 288 | "embed_dim": [ 289 | 240, 290 | 240, 291 | 240, 292 | 240, 293 | 240, 294 | 240, 295 | 240, 296 | 240, 297 | 240, 298 | 240, 299 | 240, 300 | 240, 301 | 240, 302 | 240 303 | ], 304 | "num_params": 10.517272, 305 | "score": 1581370.75 306 | } 307 | ] 308 | } 
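Each file in `interval_cands/` maps the upper bound of a parameter interval (in millions of parameters) to the candidate architectures selected for that interval by the zero-shot search; `candidates_to_choices.py` then collapses these candidates into the per-dimension search-space YAMLs shown above. A minimal sketch of reading one entry (key and values taken from the tiny file above):

```python
import json

# load the interval-candidate file produced by the zero-shot search
with open("./interval_cands/tiny.json") as fp:
    cands = json.load(fp)

top = cands["6.0"][0]  # top-ranked candidate in the 5-6M-parameter interval
print(top["layer_num"], top["embed_dim"][0], top["num_params"], top["score"])
# prints: 12 192 5.99476 1010076.75
```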
-------------------------------------------------------------------------------- /lib/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | import yaml 3 | 4 | 5 | class edict(EasyDict): 6 | def __setattr__(self, name, value): 7 | if isinstance(value, (list, tuple)): 8 | value = [self.__class__(x) 9 | if isinstance(x, dict) else x for x in value] 10 | elif isinstance(value, dict) and not isinstance(value, self.__class__): 11 | if not isinstance(next(iter(value)), int): 12 | value = self.__class__(value) 13 | dict.__setattr__(self, name, value) 14 | dict.__setitem__(self, name, value) 15 | 16 | 17 | cfg = edict() 18 | 19 | 20 | def _edict2dict(dest_dict, src_edict): 21 | if isinstance(dest_dict, dict) and isinstance(src_edict, dict): 22 | for k, v in src_edict.items(): 23 | if not isinstance(v, edict): 24 | dest_dict[k] = v 25 | else: 26 | dest_dict[k] = {} 27 | _edict2dict(dest_dict[k], v) 28 | else: 29 | return 30 | 31 | def gen_config(config_file): 32 | cfg_dict = {} 33 | _edict2dict(cfg_dict, cfg) 34 | with open(config_file, 'w') as f: 35 | yaml.dump(cfg_dict, f, default_flow_style=False) 36 | 37 | 38 | def _update_config(base_cfg, exp_cfg): 39 | if isinstance(base_cfg, edict) and isinstance(exp_cfg, edict): 40 | for k, v in exp_cfg.items(): 41 | base_cfg[k] = v 42 | else: 43 | return 44 | 45 | 46 | def update_config_from_file(filename): 47 | exp_config = None 48 | with open(filename) as f: 49 | exp_config = edict(yaml.safe_load(f)) 50 | _update_config(cfg, exp_config) 51 | 52 | 53 | -------------------------------------------------------------------------------- /lib/cuda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.utils.clip_grad import dispatch_clip_grad 3 | 4 | 5 | class NativeScaler: 6 | state_dict_key = "amp_scaler" 7 | 8 | def __init__(self): 9 | self._scaler = torch.cuda.amp.GradScaler() 10 | 11 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): 12 | #self._scaler.scale(loss).backward(create_graph=create_graph) 13 | if clip_grad is not None: 14 | assert parameters is not None 15 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 16 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 17 | self._scaler.step(optimizer) 18 | self._scaler.update() 19 | 20 | def state_dict(self): 21 | return self._scaler.state_dict() 22 | 23 | def load_state_dict(self, state_dict): 24 | self._scaler.load_state_dict(state_dict) 25 | -------------------------------------------------------------------------------- /lib/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import torch 5 | import scipy 6 | import scipy.io as sio 7 | from skimage import io 8 | 9 | from torchvision import datasets, transforms 10 | from torchvision.datasets.folder import ImageFolder, default_loader 11 | 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.data import create_transform 14 | 15 | class Flowers(ImageFolder): 16 | def __init__(self, root, train=True, transform=None, **kwargs): 17 | self.dataset_root = root 18 | self.loader = default_loader 19 | self.target_transform = None 20 | self.transform = transform 21 | label_path = os.path.join(root, 'imagelabels.mat') 22 | split_path = os.path.join(root, 'setid.mat') 23 | 24 | 
print('Dataset Flowers is trained with resolution 224!') 25 | 26 | # labels 27 | labels = sio.loadmat(label_path)['labels'][0] 28 | self.img_to_label = dict() 29 | for i in range(len(labels)): 30 | self.img_to_label[i] = labels[i] 31 | 32 | splits = sio.loadmat(split_path) 33 | self.trnid, self.valid, self.tstid = sorted(splits['trnid'][0].tolist()), \ 34 | sorted(splits['valid'][0].tolist()), \ 35 | sorted(splits['tstid'][0].tolist()) 36 | if train: 37 | self.imgs = self.trnid + self.valid 38 | else: 39 | self.imgs = self.tstid 40 | 41 | self.samples = [] 42 | for item in self.imgs: 43 | self.samples.append((os.path.join(root, 'jpg', "image_{:05d}.jpg".format(item)), self.img_to_label[item-1]-1)) 44 | 45 | class Cars196(ImageFolder, datasets.CIFAR10): 46 | base_folder_devkit = 'devkit' 47 | base_folder_trainims = 'cars_train' 48 | base_folder_testims = 'cars_test' 49 | 50 | filename_testanno = 'cars_test_annos.mat' 51 | filename_trainanno = 'cars_train_annos.mat' 52 | 53 | base_folder = 'cars_train' 54 | train_list = [ 55 | ['00001.jpg', '8df595812fee3ca9a215e1ad4b0fb0c4'], 56 | ['00002.jpg', '4b9e5efcc3612378ec63a22f618b5028'] 57 | ] 58 | test_list = [] 59 | num_training_classes = 98 # 196/2 60 | 61 | def __init__(self, root, train=False, transform=None, target_transform=None, **kwargs): 62 | self.root = root 63 | self.transform = transform 64 | 65 | self.target_transform = target_transform 66 | self.loader = default_loader 67 | print('Dataset Cars196 is trained with resolution 224!') 68 | 69 | self.samples = [] 70 | self.nb_classes = 196 71 | 72 | if train: 73 | labels = \ 74 | sio.loadmat(os.path.join(self.root, self.base_folder_devkit, self.filename_trainanno))['annotations'][0] 75 | for item in labels: 76 | img_name = item[-1].tolist()[0] 77 | label = int(item[4]) - 1 78 | self.samples.append((os.path.join(self.root, self.base_folder_trainims, img_name), label)) 79 | else: 80 | labels = \ 81 | sio.loadmat(os.path.join(self.root, 'cars_test_annos_withlabels.mat'))['annotations'][0] 82 | for item in labels: 83 | img_name = item[-1].tolist()[0] 84 | label = int(item[-2]) - 1 85 | self.samples.append((os.path.join(self.root, self.base_folder_testims, img_name), label)) 86 | 87 | class Pets(ImageFolder): 88 | def __init__(self, root, train=True, transform=None, target_transform=None, **kwargs): 89 | self.dataset_root = root 90 | self.loader = default_loader 91 | self.target_transform = None 92 | self.transform = transform 93 | train_list_path = os.path.join(self.dataset_root, 'annotations', 'trainval.txt') 94 | test_list_path = os.path.join(self.dataset_root, 'annotations', 'test.txt') 95 | 96 | self.samples = [] 97 | if train: 98 | with open(train_list_path, 'r') as f: 99 | for line in f: 100 | img_name = line.split(' ')[0] 101 | label = int(line.split(' ')[1]) 102 | self.samples.append((os.path.join(root, 'images', "{}.jpg".format(img_name)), label-1)) 103 | else: 104 | with open(test_list_path, 'r') as f: 105 | for line in f: 106 | img_name = line.split(' ')[0] 107 | label = int(line.split(' ')[1]) 108 | self.samples.append((os.path.join(root, 'images', "{}.jpg".format(img_name)), label-1)) 109 | 110 | class INatDataset(ImageFolder): 111 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 112 | category='name', loader=default_loader): 113 | self.transform = transform 114 | self.loader = loader 115 | self.target_transform = target_transform 116 | self.year = year 117 | # assert category in 
['kingdom','phylum','class','order','supercategory','family','genus','name'] 118 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 119 | with open(path_json) as json_file: 120 | data = json.load(json_file) 121 | 122 | with open(os.path.join(root, 'categories.json')) as json_file: 123 | data_catg = json.load(json_file) 124 | 125 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 126 | 127 | with open(path_json_for_targeter) as json_file: 128 | data_for_targeter = json.load(json_file) 129 | 130 | targeter = {} 131 | indexer = 0 132 | for elem in data_for_targeter['annotations']: 133 | king = [] 134 | king.append(data_catg[int(elem['category_id'])][category]) 135 | if king[0] not in targeter.keys(): 136 | targeter[king[0]] = indexer 137 | indexer += 1 138 | self.nb_classes = len(targeter) 139 | 140 | self.samples = [] 141 | for elem in data['images']: 142 | cut = elem['file_name'].split('/') 143 | target_current = int(cut[2]) 144 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 145 | 146 | categors = data_catg[target_current] 147 | target_current_true = targeter[categors[category]] 148 | self.samples.append((path_current, target_current_true)) 149 | 150 | # __getitem__ and __len__ inherited from ImageFolder 151 | 152 | def build_dataset(is_train, args, folder_name=None): 153 | transform = build_transform(is_train, args) 154 | 155 | if args.data_set == 'CIFAR10': 156 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform, download=True) 157 | nb_classes = 10 158 | elif args.data_set == 'CIFAR100': 159 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 160 | nb_classes = 100 161 | elif args.data_set == 'CARS': 162 | dataset = Cars196(args.data_path, train=is_train, transform=transform) 163 | nb_classes = 196 164 | elif args.data_set == 'PETS': 165 | dataset = Pets(args.data_path, train=is_train, transform=transform) 166 | nb_classes = 37 167 | elif args.data_set == 'FLOWERS': 168 | dataset = Flowers(args.data_path, train=is_train, transform=transform) 169 | nb_classes = 102 170 | elif args.data_set == 'IMNET': 171 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 172 | dataset = datasets.ImageFolder(root, transform=transform) 173 | nb_classes = 1000 174 | elif args.data_set == 'EVO_IMNET': 175 | root = os.path.join(args.data_path, folder_name) 176 | dataset = datasets.ImageFolder(root, transform=transform) 177 | nb_classes = 1000 178 | elif args.data_set == 'INAT': 179 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 180 | category=args.inat_category, transform=transform) 181 | nb_classes = dataset.nb_classes 182 | elif args.data_set == 'INAT19': 183 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 184 | category=args.inat_category, transform=transform) 185 | nb_classes = dataset.nb_classes 186 | 187 | return dataset, nb_classes 188 | 189 | def build_transform(is_train, args): 190 | resize_im = args.input_size > 32 191 | if is_train: 192 | # this should always dispatch to transforms_imagenet_train 193 | transform = create_transform( 194 | input_size=args.input_size, 195 | is_training=True, 196 | color_jitter=args.color_jitter, 197 | auto_augment=args.aa, 198 | interpolation=args.train_interpolation, 199 | re_prob=args.reprob, 200 | re_mode=args.remode, 201 | re_count=args.recount, 202 | ) 203 | if not resize_im: 204 | # replace RandomResizedCropAndInterpolation with 205 | # RandomCrop 206 | transform.transforms[0] 
= transforms.RandomCrop( 207 | args.input_size, padding=4) 208 | return transform 209 | 210 | t = [] 211 | if resize_im: 212 | size = int((256 / 224) * args.input_size) 213 | t.append( 214 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 215 | ) 216 | #t.append(transforms.CenterCrop(args.input_size)) 217 | crop_tf = { 218 | 1: transforms.CenterCrop, 219 | 5: transforms.FiveCrop, 220 | 10: transforms.TenCrop, 221 | } 222 | t.append(crop_tf[args.eval_crops](args.input_size)) 223 | 224 | if resize_im and args.eval_crops > 1: 225 | t.append(transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops]))) 226 | else: 227 | t.append(transforms.ToTensor()) 228 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 229 | return transforms.Compose(t) 230 | -------------------------------------------------------------------------------- /lib/imagenet_withhold.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | import io 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | 9 | 10 | class ImageNet_Withhold(Dataset): 11 | def __init__(self, data_root, ann_file='', transform=None, train=True, task ='train'): 12 | super(ImageNet_Withhold, self).__init__() 13 | ann_file = ann_file + '/' + 'val_true.txt' 14 | train_split = (task == 'train' or task == 'val') 15 | self.data_root = data_root + '/'+ ('train' if train_split else 'val') 16 | 17 | self.data = [] 18 | self.nb_classes = 0 19 | folders = {} 20 | cnt = 0 21 | self.z = ZipReader() 22 | # if train: 23 | # for member in self.tarfile.getmembers(): 24 | # print(member) 25 | # self.tarfile = tarfile.open(self.data_root) 26 | 27 | f = open(ann_file) 28 | prefix = 'data/sdb/imagenet'+'/'+ ('train' if train_split else 'val') + '/' 29 | for line in f: 30 | tmp = line.strip().split('\t')[0] 31 | class_pic = tmp.split('/') 32 | class_tmp = class_pic[0] 33 | pic = class_pic[1] 34 | 35 | if class_tmp in folders: 36 | # print(self.tarfile.getmember(('train/' if train else 'val/') + tmp[0] + '.JPEG')) 37 | self.data.append((class_tmp + '.zip', prefix + tmp + '.JPEG', folders[class_tmp])) 38 | else: 39 | folders[class_tmp] = cnt 40 | cnt += 1 41 | self.data.append((class_tmp + '.zip', prefix + tmp + '.JPEG',folders[class_tmp])) 42 | 43 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 44 | std=[0.229, 0.224, 0.225]) 45 | if transform is not None: 46 | self.transforms = transform 47 | else: 48 | if train: 49 | self.transforms = transforms.Compose([ 50 | transforms.RandomSizedCrop(224), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), 53 | normalize, 54 | ]) 55 | else: 56 | self.transforms = transforms.Compose([ 57 | transforms.Scale(256), 58 | transforms.CenterCrop(224), 59 | transforms.ToTensor(), 60 | normalize, 61 | ]) 62 | 63 | 64 | self.nb_classes = cnt 65 | def __len__(self): 66 | return len(self.data) 67 | 68 | def __getitem__(self, idx): 69 | 70 | # print('extract_file', time.time()-start_time) 71 | iob = self.z.read(self.data_root + '/' + self.data[idx][0], self.data[idx][1]) 72 | iob = io.BytesIO(iob) 73 | img = Image.open(iob).convert('RGB') 74 | target = self.data[idx][2] 75 | if self.transforms is not None: 76 | img = self.transforms(img) 77 | # print('open', time.time()-start_time) 78 | return img, target 79 | -------------------------------------------------------------------------------- /lib/samplers.py: 
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.distributed as dist
 3 | import math
 4 | 
 5 | 
 6 | class RASampler(torch.utils.data.Sampler):
 7 |     """Sampler that restricts data loading to a subset of the dataset for distributed
 8 |     training, with repeated augmentation.
 9 |     It ensures that each augmented version of a sample is visible to a
10 |     different process (GPU).
11 |     Heavily based on torch.utils.data.DistributedSampler
12 |     """
13 | 
14 |     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
15 |         if num_replicas is None:
16 |             if not dist.is_available():
17 |                 raise RuntimeError("Requires distributed package to be available")
18 |             num_replicas = dist.get_world_size()
19 |         if rank is None:
20 |             if not dist.is_available():
21 |                 raise RuntimeError("Requires distributed package to be available")
22 |             rank = dist.get_rank()
23 |         self.dataset = dataset
24 |         self.num_replicas = num_replicas
25 |         self.rank = rank
26 |         self.epoch = 0
27 |         self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
28 |         self.total_size = self.num_samples * self.num_replicas
29 |         # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
30 |         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
31 |         self.shuffle = shuffle
32 | 
33 |     def __iter__(self):
34 |         # deterministically shuffle based on epoch
35 |         g = torch.Generator()
36 |         g.manual_seed(self.epoch)
37 |         if self.shuffle:
38 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
39 |         else:
40 |             indices = list(range(len(self.dataset)))
41 | 
42 |         # add extra samples to make it evenly divisible
43 |         indices = [ele for ele in indices for i in range(3)]
44 |         indices += indices[:(self.total_size - len(indices))]
45 |         assert len(indices) == self.total_size
46 | 
47 |         # subsample
48 |         indices = indices[self.rank:self.total_size:self.num_replicas]
49 | 
50 |         assert len(indices) == self.num_samples
51 |         return iter(indices[:self.num_selected_samples])
52 | 
53 |     def __len__(self):
54 |         return self.num_selected_samples
55 | 
56 |     def set_epoch(self, epoch):
57 |         self.epoch = epoch
58 | 
--------------------------------------------------------------------------------
/lib/score_maker.py:
--------------------------------------------------------------------------------
 1 | import copy
 2 | import torch
 3 | import math
 4 | import logging
 5 | from torch import optim
 6 | import torch.nn.functional as F
 7 | import torch.distributed as dist
 8 | from contextlib import suppress
 9 | from scipy import stats
10 | import numpy as np
11 | from sklearn.metrics import roc_auc_score
12 | from tqdm import tqdm
13 | import random
14 | import functools
15 | import torch.distributed as dist
16 | from typing import Iterable, Optional
17 | from timm.data import Mixup
18 | from timm.optim import create_optimizer
19 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20 | 
21 | from model.supernet_transformer import Vision_TransformerSuper
22 | 
23 | 
24 | def is_master(args, local=False):
25 |     return is_local_master(args) if local else is_global_master(args)  # NOTE: is_local_master/is_global_master are not defined in this repo; this helper appears unused
26 | 
27 | 
28 | def unwrap_model(model):
29 |     if hasattr(model, 'module'):
30 |         return model.module
31 |     else:
32 |         return model
33 | 
34 | 
35 | def gen_key(embed_dim, choice):
36 |     return f'{embed_dim},{choice}'
37 | 
38 | 
39 | class ScoreMaker(object):
40 |     def __init__(self):
41 |         self.grad_dict_before_train = {}
42 | 
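        # caches populated during scoring: averaged supernet gradients, raw
        # parameter values, and per-weight saliency scores keyed by tensor name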
self.grad_dict_after_train = {} 43 | self.param_val_dict = {} 44 | self.item_score_dict = {} 45 | self.key_items = ['attn.qkv.weight', 'fc1.weight'] 46 | 47 | def drop_gradient(self): 48 | self.grad_dict = None 49 | 50 | def nan_to_zero(self, a): 51 | return torch.where(torch.isnan(a), torch.full_like(a, 0), a) 52 | 53 | def build_avg_image(self, s): 54 | # 3 channel in image 55 | assert s[1] == 3 56 | 57 | img = torch.zeros(s) 58 | mean = IMAGENET_DEFAULT_MEAN 59 | std = IMAGENET_DEFAULT_STD 60 | for i in range(3): 61 | torch.nn.init.normal_(img[:, i, :, :], mean=mean[i], std=std[i]) 62 | return img 63 | 64 | def get_gradient(self, model, criterion, data_loader, args, choices, device, mixup_fn: Optional[Mixup] = None): 65 | config = {} 66 | dimensions = ['mlp_ratio', 'num_heads'] 67 | depth = max(choices['depth']) 68 | for dimension in dimensions: 69 | config[dimension] = [max(choices[dimension]) for _ in range(depth)] 70 | config['embed_dim'] = [max(choices['embed_dim'])] * depth 71 | config['layer_num'] = depth 72 | 73 | model_module = unwrap_model(model) 74 | model_module.set_sample_config(config=config) 75 | 76 | model.train() 77 | criterion.train() 78 | 79 | random.seed(0) 80 | 81 | optimizer = create_optimizer(args, model_module) 82 | 83 | batch_num = 0 84 | grad_dict = {} 85 | 86 | optimizer.zero_grad() 87 | 88 | for samples, targets in data_loader: 89 | samples = samples.to(device, non_blocking=True) 90 | targets = targets.to(device, non_blocking=True) 91 | 92 | if mixup_fn is not None: 93 | samples, targets = mixup_fn(samples, targets) 94 | 95 | if args.data_free: 96 | input_dim = list(samples[0, :].shape) 97 | inputs = self.build_avg_image([64] + input_dim).to(device) # 64 batch image 98 | output = model.forward(inputs) 99 | torch.sum(output).backward() 100 | batch_num += 1 101 | print('data free!') 102 | break 103 | 104 | outputs = model(samples) 105 | loss = criterion(outputs, targets) 106 | loss.backward() 107 | batch_num += 1 108 | 109 | for k, param in model_module.named_parameters(): 110 | if param.requires_grad: 111 | grad_dict[k] = param.grad 112 | 113 | if args.distributed: 114 | dist.barrier() 115 | result = torch.tensor([batch_num]).to(args.device, non_blocking=True) 116 | dist.all_reduce(result) 117 | batch_num = result[0] 118 | for k, v in grad_dict.items(): 119 | dist.all_reduce(grad_dict[k]) 120 | 121 | for k, v in grad_dict.items(): 122 | grad_dict[k] = v / batch_num 123 | # grad_dict[k] = grad_dict[k].cpu() 124 | 125 | self.grad_dict = grad_dict 126 | 127 | def get_block_scores(self, model, args, choices): 128 | model_module = unwrap_model(model) 129 | param_val_dict = model_module.state_dict() 130 | 131 | head_choices = choices['num_heads'] 132 | mlp_choices = choices['mlp_ratio'] 133 | layers = max(choices['depth']) 134 | embed_dims = choices['embed_dim'] 135 | max_dim = max(choices['embed_dim']) 136 | head_dim = model_module.super_embed_dim // model_module.super_num_heads 137 | 138 | head_score_dict = {} 139 | mlp_score_dict = {} 140 | for embed_dim in embed_dims: 141 | for j in range(len(head_choices) - 1): 142 | head_score_dict[gen_key(embed_dim, head_choices[j + 1])] = [] 143 | for j in range(len(mlp_choices) - 1): 144 | mlp_score_dict[gen_key(embed_dim, mlp_choices[j + 1])] = [] 145 | 146 | def get_item_score(pv, gv, block_score_method): 147 | pg = pv.mul(gv) 148 | 149 | if 'balance_taylor6_norm' in block_score_method: 150 | item_score = pg.abs() / pv.abs().sum() / gv.abs().sum() / pg.abs().sum() 151 | elif 'taylor6_doublenorm' in block_score_method: 152 
| item_score = pg.abs() / pg.abs().sum() / pg.abs().sum() 153 | elif 'taylor6_norm' in block_score_method: 154 | item_score = pg.abs() / pg.abs().sum() 155 | elif 'balance_taylor6' in block_score_method: 156 | item_score = pg.abs() / pv.abs().sum() / gv.abs().sum() 157 | elif 'taylor6' in block_score_method: 158 | item_score = pg.abs() 159 | elif 'balance_taylor5_norm' in block_score_method: 160 | item_score = pg / pv.sum().abs() / gv.sum().abs() / pg.sum().abs() 161 | elif 'taylor5_doublenorm' in block_score_method: 162 | item_score = pg / pg.sum().abs() / pg.sum().abs() 163 | elif 'taylor5_norm' in block_score_method: 164 | item_score = pg / pg.sum().abs() 165 | elif 'balance_taylor5' in block_score_method: 166 | item_score = pg / pv.sum().abs() / gv.sum().abs() 167 | elif 'taylor5' in block_score_method: 168 | item_score = pg 169 | elif 'taylor9_norm' in block_score_method: 170 | item_score = gv.abs() / gv.abs().sum() 171 | elif 'taylor9_doublenorm' in block_score_method: 172 | item_score = gv.abs() / gv.abs().sum() / gv.abs().sum() 173 | elif 'taylor9' in block_score_method: 174 | item_score = gv.abs() 175 | elif 'l1norm' in block_score_method: 176 | item_score = pv.abs() 177 | else: 178 | item_score = pv 179 | return item_score 180 | 181 | for embed_dim in embed_dims: 182 | for i in range(layers): 183 | qkv_w = f'blocks.{i}.attn.qkv.weight' 184 | c_fc_w = f'blocks.{i}.fc1.weight' 185 | c_proj_w = f'blocks.{i}.fc2.weight' 186 | qkv_score = get_item_score(param_val_dict[qkv_w][:, :max_dim], 187 | self.grad_dict[qkv_w][:, :max_dim], 188 | args.block_score_method_for_head)[:, :embed_dim] 189 | c_fc_score = get_item_score(param_val_dict[c_fc_w][:, :max_dim], 190 | self.grad_dict[c_fc_w][:, :max_dim], 191 | args.block_score_method_for_mlp)[:, :embed_dim] 192 | c_proj_score = get_item_score(param_val_dict[c_proj_w][:max_dim, :], 193 | self.grad_dict[c_proj_w][:max_dim, :], 194 | args.block_score_method_for_mlp)[:embed_dim, :] 195 | for j in range(len(head_choices) - 1): 196 | qkv_embed_base = head_dim * head_choices[j] 197 | qkv_embed_dim = head_dim * head_choices[j + 1] 198 | left_qkv_score = torch.cat([qkv_score[qkv_embed_base * 3 + k:qkv_embed_dim * 3:3, :] for k in range(3)], dim=0) 199 | score = left_qkv_score.sum().abs().cpu() 200 | head_score_dict[gen_key(embed_dim, head_choices[j + 1])].append(score) 201 | for j in range(len(mlp_choices) - 1): 202 | mlp_embed_base = int(embed_dim * mlp_choices[j]) 203 | mlp_embed_dim = int(embed_dim * mlp_choices[j + 1]) 204 | left_c_fc_score = c_fc_score[mlp_embed_base:mlp_embed_dim, :] 205 | left_c_proj_score = c_proj_score[:, mlp_embed_base:mlp_embed_dim] 206 | score = left_c_fc_score.sum().abs().cpu() + left_c_proj_score.sum().abs().cpu() 207 | mlp_score_dict[gen_key(embed_dim, mlp_choices[j + 1])].append(score) 208 | 209 | return {'head_scores': head_score_dict, 'mlp_scores': mlp_score_dict} 210 | 211 | 212 | def get_item_score(self, model, criterion, data_loader, args, choices, device, mixup_fn: Optional[Mixup] = None): 213 | 214 | config = {} 215 | dimensions = ['mlp_ratio', 'num_heads'] 216 | depth = max(choices['depth']) 217 | for dimension in dimensions: 218 | config[dimension] = [max(choices[dimension]) for _ in range(depth)] 219 | config['embed_dim'] = [max(choices['embed_dim'])] * depth 220 | config['layer_num'] = depth 221 | 222 | model_module = unwrap_model(model) 223 | model_module.set_sample_config(config=config) 224 | 225 | param_val_dict = model_module.state_dict() 226 | grad_dict = {} 227 | 228 | if 'taylor' in args.score_method: 
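            # Taylor-family scores need gradient information: run forward/backward
            # over the loader on the largest sampled subnet, then average the
            # accumulated gradients over batches (and across workers when distributed)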
229 | model.train() 230 | criterion.train() 231 | 232 | random.seed(0) 233 | 234 | optimizer = create_optimizer(args, model_module) 235 | 236 | batch_num = 0 237 | 238 | for samples, targets in data_loader: 239 | samples = samples.to(device, non_blocking=True) 240 | targets = targets.to(device, non_blocking=True) 241 | 242 | if mixup_fn is not None: 243 | samples, targets = mixup_fn(samples, targets) 244 | 245 | outputs = model(samples) 246 | loss = criterion(outputs, targets) 247 | 248 | optimizer.zero_grad() 249 | loss.backward() 250 | 251 | for k, param in model_module.named_parameters(): 252 | if param.requires_grad: 253 | if batch_num == 0: 254 | grad_dict[k] = copy.deepcopy(param.grad) 255 | else: 256 | grad_dict[k] = grad_dict[k] + param.grad 257 | 258 | batch_num += 1 259 | 260 | if args.distributed: 261 | dist.barrier() 262 | result = torch.tensor([batch_num]).to(args.device, non_blocking=True) 263 | dist.all_reduce(result) 264 | batch_num = result[0] 265 | for k, v in grad_dict.items(): 266 | dist.all_reduce(grad_dict[k]) 267 | 268 | for k, v in grad_dict.items(): 269 | grad_dict[k] = v / batch_num 270 | grad_dict[k] = grad_dict[k] 271 | 272 | for k in param_val_dict.keys(): 273 | for key_item in self.key_items: 274 | if key_item in k: 275 | if 'l1norm' in args.score_method: 276 | self.item_score_dict[k] = param_val_dict[k].abs().cpu() 277 | elif 'taylor5' in args.score_method: 278 | self.item_score_dict[k] = param_val_dict[k].mul(grad_dict[k]).cpu() 279 | elif 'taylor6' in args.score_method: 280 | self.item_score_dict[k] = param_val_dict[k].mul(grad_dict[k]).abs().cpu() 281 | elif 'taylor9' in args.score_method: 282 | self.item_score_dict[k] = grad_dict[k].abs().cpu() 283 | else: 284 | assert False 285 | 286 | def get_head_score(self, head_choices, layers, head_dim, embed_dims, layer_norm=False): 287 | score_dict = {} 288 | for embed_dim in embed_dims: 289 | for j in range(len(head_choices) - 1): 290 | score_dict[gen_key(embed_dim, head_choices[j + 1])] = [] 291 | for embed_dim in embed_dims: 292 | for i in range(layers): 293 | qkv_w = f'blocks.{i}.attn.qkv.weight' 294 | for j in range(len(head_choices) - 1): 295 | qkv_embed_base = head_dim * head_choices[j] 296 | qkv_embed_dim = head_dim * head_choices[j + 1] 297 | qkv_score = self.item_score_dict[qkv_w][:,:embed_dim] 298 | item_score = torch.cat([qkv_score[qkv_embed_base * 3 + k:qkv_embed_dim * 3:3, :] for k in range(3)], dim=0) 299 | score = item_score.sum().abs() 300 | if layer_norm: 301 | score = score / qkv_score.sum().abs() 302 | score_dict[gen_key(embed_dim, head_choices[j + 1])].append(score) 303 | return score_dict 304 | 305 | def get_mlp_score(self, mlp_choices, layers, embed_dims, layer_norm=False): 306 | score_dict = {} 307 | for embed_dim in embed_dims: 308 | for j in range(len(mlp_choices) - 1): 309 | score_dict[gen_key(embed_dim, mlp_choices[j + 1])] = [] 310 | for embed_dim in embed_dims: 311 | for i in range(layers): 312 | c_fc_w = f'blocks.{i}.fc1.weight' 313 | for j in range(len(mlp_choices) - 1): 314 | mlp_embed_base = int(embed_dim * mlp_choices[j]) 315 | mlp_embed_dim = int(embed_dim * mlp_choices[j+1]) 316 | mlp_score = self.item_score_dict[c_fc_w][:, :embed_dim] 317 | item_score = mlp_score[mlp_embed_base:mlp_embed_dim, :] 318 | score = item_score.sum().abs() 319 | if layer_norm: 320 | score = score / mlp_score.sum().abs() 321 | score_dict[gen_key(embed_dim, mlp_choices[j + 1])].append(score) 322 | return score_dict 323 | 324 | def get_left_part_from_super_model(self, model: Vision_TransformerSuper, 
para_dict, sample_config): 325 | layers = model.super_layer_num 326 | sample_layers = sample_config['layer_num'] 327 | left_dict = {} 328 | 329 | embed_dims = sample_config['embed_dim'] 330 | output_dims = [out_dim for out_dim in sample_config['embed_dim'][1:]] + [sample_config['embed_dim'][-1]] 331 | 332 | left_dict['patch_embed_super.proj.weight'] = para_dict['patch_embed_super.proj.weight'][:embed_dims[0], ...] 333 | left_dict['patch_embed_super.proj.bias'] = para_dict['patch_embed_super.proj.bias'][:embed_dims[0], ...] 334 | left_dict['norm.weight'] = para_dict['norm.weight'][:embed_dims[-1]] 335 | left_dict['norm.bias'] = para_dict['norm.bias'][:embed_dims[-1]] 336 | left_dict['head.weight'] = para_dict['head.weight'][:, :embed_dims[-1]] 337 | left_dict['head.bias'] = para_dict['head.bias'][:embed_dims[-1]] 338 | 339 | for i in range(layers): 340 | qkv_w = f'blocks.{i}.attn.qkv.weight' 341 | qkv_b = f'blocks.{i}.attn.qkv.bias' 342 | proj_w = f'blocks.{i}.attn.proj.weight' 343 | proj_b = f'blocks.{i}.attn.proj.bias' 344 | ln1_w = f'blocks.{i}.attn_layer_norm.weight' 345 | ln1_b = f'blocks.{i}.attn_layer_norm.bias' 346 | c_fc_w = f'blocks.{i}.fc1.weight' 347 | c_fc_b = f'blocks.{i}.fc1.bias' 348 | c_proj_w = f'blocks.{i}.fc2.weight' 349 | c_proj_b = f'blocks.{i}.fc2.bias' 350 | ln2_w = f'blocks.{i}.ffn_layer_norm.weight' 351 | ln2_b = f'blocks.{i}.ffn_layer_norm.bias' 352 | if i < sample_layers: 353 | num_heads = sample_config['num_heads'][i] 354 | head_dim = model.super_embed_dim // model.super_num_heads 355 | qk_embed_dim = head_dim * num_heads 356 | mlp_ratio = sample_config['mlp_ratio'][i] 357 | embed_dim = embed_dims[i] 358 | mlp_width = int(embed_dim * mlp_ratio) 359 | output_dim = output_dims[i] 360 | 361 | left_dict[qkv_w] = para_dict[qkv_w][:, :embed_dim] 362 | left_dict[qkv_w] = torch.cat([left_dict[qkv_w][i:qk_embed_dim*3:3, :] for i in range(3)], dim=0) 363 | 364 | # left_dict[qkv_b] = para_dict[qkv_b][:qk_embed_dim*3] 365 | left_dict[qkv_b] = torch.cat([para_dict[qkv_b][i:qk_embed_dim*3:3] for i in range(3)]) 366 | 367 | left_dict[proj_w] = para_dict[proj_w][:, :qk_embed_dim] 368 | left_dict[proj_w] = left_dict[proj_w][:embed_dim, :] 369 | 370 | left_dict[proj_b] = para_dict[proj_b][:embed_dim] 371 | 372 | left_dict[ln1_w] = para_dict[ln1_w][:embed_dim] 373 | 374 | left_dict[ln1_b] = para_dict[ln1_b][:embed_dim] 375 | 376 | left_dict[c_fc_w] = para_dict[c_fc_w][:, :embed_dim] 377 | left_dict[c_fc_w] = left_dict[c_fc_w][:mlp_width, :] 378 | 379 | left_dict[c_fc_b] = para_dict[c_fc_b][:mlp_width] 380 | 381 | left_dict[c_proj_w] = para_dict[c_proj_w][:, :mlp_width] 382 | left_dict[c_proj_w] = left_dict[c_proj_w][:output_dim, :] 383 | 384 | left_dict[c_proj_b] = para_dict[c_proj_b][:output_dim] 385 | 386 | left_dict[ln2_w] = para_dict[ln2_w][:output_dim] 387 | 388 | left_dict[ln2_b] = para_dict[ln2_b][:output_dim] 389 | else: 390 | continue 391 | 392 | num_paras = 0 393 | for k, v in left_dict.items(): 394 | num_paras += v.numel() 395 | 396 | return left_dict, num_paras 397 | 398 | def get_left_attn_mlp_from_super_model(self, model: Vision_TransformerSuper, para_dict, sample_config): 399 | 400 | layers = model.super_layer_num 401 | sample_layers = sample_config['layer_num'] 402 | left_dict = {} 403 | 404 | embed_dims = sample_config['embed_dim'] 405 | output_dims = [out_dim for out_dim in sample_config['embed_dim'][1:]] + [sample_config['embed_dim'][-1]] 406 | 407 | for i in range(layers): 408 | qkv_w = f'blocks.{i}.attn.qkv.weight' 409 | qkv_b = f'blocks.{i}.attn.qkv.bias' 410 
| proj_w = f'blocks.{i}.attn.proj.weight' 411 | proj_b = f'blocks.{i}.attn.proj.bias' 412 | c_fc_w = f'blocks.{i}.fc1.weight' 413 | c_fc_b = f'blocks.{i}.fc1.bias' 414 | c_proj_w = f'blocks.{i}.fc2.weight' 415 | c_proj_b = f'blocks.{i}.fc2.bias' 416 | if i < sample_layers: 417 | num_heads = sample_config['num_heads'][i] 418 | head_dim = model.super_embed_dim // model.super_num_heads 419 | qk_embed_dim = head_dim * num_heads 420 | mlp_ratio = sample_config['mlp_ratio'][i] 421 | embed_dim = embed_dims[i] 422 | mlp_width = int(embed_dim * mlp_ratio) 423 | output_dim = output_dims[i] 424 | 425 | left_dict[qkv_w] = para_dict[qkv_w][:, :embed_dim] 426 | left_dict[qkv_w] = torch.cat([left_dict[qkv_w][i:qk_embed_dim * 3:3, :] for i in range(3)], dim=0) 427 | 428 | # left_dict[qkv_b] = para_dict[qkv_b][:qk_embed_dim*3] 429 | left_dict[qkv_b] = torch.cat([para_dict[qkv_b][i:qk_embed_dim * 3:3] for i in range(3)], dim=0) 430 | 431 | left_dict[proj_w] = para_dict[proj_w][:, :qk_embed_dim] 432 | left_dict[proj_w] = left_dict[proj_w][:embed_dim, :] 433 | 434 | left_dict[proj_b] = para_dict[proj_b][:embed_dim] 435 | 436 | left_dict[c_fc_w] = para_dict[c_fc_w][:, :embed_dim] 437 | left_dict[c_fc_w] = left_dict[c_fc_w][:mlp_width, :] 438 | 439 | left_dict[c_fc_b] = para_dict[c_fc_b][:mlp_width] 440 | 441 | left_dict[c_proj_w] = para_dict[c_proj_w][:, :mlp_width] 442 | left_dict[c_proj_w] = left_dict[c_proj_w][:output_dim, :] 443 | 444 | left_dict[c_proj_b] = para_dict[c_proj_b][:output_dim] 445 | else: 446 | continue 447 | 448 | num_paras = 0 449 | for k, v in left_dict.items(): 450 | num_paras += v.numel() 451 | 452 | return left_dict, num_paras 453 | 454 | def get_scores(self, model, score_methods, config): 455 | score_methods = score_methods.strip().split('+') 456 | scores = [] 457 | for score_method in score_methods: 458 | scores.append(self.get_score(model, score_method, config)) 459 | return scores 460 | 461 | def get_score(self, model, score_method, config): 462 | 463 | if score_method == 'entropy': 464 | depth, mlp_ratio, num_heads, embed_dim = config['layer_num'], config['mlp_ratio'], config['num_heads'], config['embed_dim'] 465 | entropy_score = 0. 
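            # data-free proxy: each layer adds log(d_f) + log(d_h * n_h) + log(n) + 4 * log(d),
            # so wider MLPs, more/wider heads, and larger embed dims all raise the score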
466 |             for i in range(depth):
467 |                 d = embed_dim[i]
468 |                 n = 14 * 14  # input_size = 224, patch_size = 16
469 |                 d_f = mlp_ratio[i] * d
470 |                 d_h = 64
471 |                 n_h = num_heads[i]
472 |                 entropy_score += math.log(d_f) + math.log(d_h * n_h) + math.log(n) + 4 * math.log(d)
473 |             return entropy_score
474 | 
475 |         super_paras = unwrap_model(model).state_dict()
476 | 
477 |         if 'left_attn_mlp' in score_method:
478 |             para, num_paras = self.get_left_attn_mlp_from_super_model(unwrap_model(model), super_paras, config)
479 |         else:
480 |             para, num_paras = self.get_left_part_from_super_model(unwrap_model(model), super_paras, config)
481 |         grad = None
482 |         if 'taylor' in score_method:
483 |             if 'left_attn_mlp' in score_method:
484 |                 grad, _ = self.get_left_attn_mlp_from_super_model(unwrap_model(model), self.grad_dict, config)
485 |             else:
486 |                 grad, _ = self.get_left_part_from_super_model(unwrap_model(model), self.grad_dict, config)
487 | 
488 |         if 'avg' not in score_method:
489 |             num_paras = None
490 | 
491 |         if 'l1norm_norm' in score_method:  # test the '*_norm' variants first: 'l1norm', 'taylor5', etc. are substrings of them
492 |             res = self.criterion_l_l1norm(para, super_paras=super_paras)
493 |         elif 'l1norm' in score_method:
494 |             res = self.criterion_l_l1norm(para, num_paras)
495 |         elif 'taylor5_norm' in score_method:
496 |             res = self.criterion_l_taylor5(para, grad, super_paras=super_paras, super_grads=self.grad_dict)  # normalize by the full-supernet tensors
497 |         elif 'taylor5' in score_method:
498 |             res = self.criterion_l_taylor5(para, grad, num_paras)
499 |         elif 'taylor6_norm' in score_method:
500 |             res = self.criterion_l_taylor6(para, grad, super_paras=super_paras, super_grads=self.grad_dict)
501 |         elif 'taylor6' in score_method:
502 |             res = self.criterion_l_taylor6(para, grad, num_paras)
503 |         elif 'taylor9_norm' in score_method:
504 |             res = self.criterion_l_taylor9(para, grad, super_paras=super_paras, super_grads=self.grad_dict)
505 |         elif 'taylor9' in score_method:
506 |             res = self.criterion_l_taylor9(para, grad, num_paras)
507 |         else:
508 |             assert False
509 | 
510 |         if 'pruned' in score_method:
511 |             res = - res
512 |         if type(res) == float:
513 |             return res
514 |         else:
515 |             return res.cpu()
516 | 
517 |     def criterion_l_l1norm(self, paras, num_paras=None, super_paras=None):
518 |         score = 0.
519 |         for k, v in paras.items():
520 |             if super_paras is not None:
521 |                 score += v.abs().sum() / super_paras[k].abs().sum()
522 |             else:
523 |                 score += v.abs().sum()
524 |         if num_paras:
525 |             score /= num_paras
526 |         return score.cpu()
527 | 
528 |     def criterion_l_l2norm(self, paras, num_paras=None, super_paras=None):
529 |         score = 0.
530 |         for k, v in paras.items():
531 |             if super_paras is not None:
532 |                 score += v.norm() / super_paras[k].norm()
533 |             else:
534 |                 score += v.norm()
535 |         if num_paras:
536 |             score /= num_paras
537 |         return score.cpu()
538 | 
539 |     def criterion_l_taylor1(self, paras, grads, num_paras=None, super_paras=None, super_grads=None):
540 |         score = 0.
541 |         for k, v in paras.items():
542 |             g = grads[k]
543 |             if super_paras is not None and super_grads is not None:
544 |                 score += v.mul(g).sum() / super_paras[k].mul(super_grads[k]).sum()
545 |             else:
546 |                 score += v.mul(g).sum()
547 |         if num_paras:
548 |             score /= num_paras
549 |         score = score ** 2
550 |         return score.cpu()
551 | 
552 |     def criterion_l_taylor2(self, paras, grads, num_paras=None, super_paras=None, super_grads=None):  # fisher
553 |         score = 0.
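        # Fisher-style saliency: sum of (w * g)^2 over the retained weights,
        # optionally normalized by the same quantity on the full supernet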
554 | for k, v in paras.items(): 555 | g = grads[k] 556 | if super_paras is not None and super_grads is not None: 557 | score += (v.mul(g) ** 2).sum() / (super_paras[k].mul(super_grads[k]) ** 2).sum() 558 | else: 559 | score += (v.mul(g) ** 2).sum() 560 | if num_paras: 561 | score /= num_paras 562 | return score.cpu() 563 | 564 | def criterion_l_taylor3(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): 565 | score = 0. 566 | for k, v in paras.items(): 567 | g = grads[k] 568 | if super_paras is not None and super_grads is not None: 569 | score += g.sum() / super_grads[k].sum() 570 | else: 571 | score += g.sum() 572 | if num_paras: 573 | score /= num_paras 574 | score = score ** 2 575 | return score.cpu() 576 | 577 | def criterion_l_taylor4(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): 578 | score = 0. 579 | for k, v in paras.items(): 580 | g = grads[k] 581 | if super_paras is not None and super_grads is not None: 582 | score += (g ** 2).sum() / (super_grads[k] ** 2).sum() 583 | else: 584 | score += (g ** 2).sum() 585 | if num_paras: 586 | score /= num_paras 587 | return score.cpu() 588 | 589 | def criterion_l_taylor5(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # synflow 590 | score = 0. 591 | for k, v in paras.items(): 592 | g = grads[k] 593 | if super_paras is not None and super_grads is not None: 594 | score += v.mul(g).sum() / super_paras[k].mul(super_grads[k]).sum() 595 | else: 596 | score += v.mul(g).sum() 597 | score = score.abs() 598 | if num_paras: 599 | score /= num_paras 600 | return score.cpu() 601 | 602 | def criterion_l_taylor6(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): 603 | score = 0. 604 | for k, v in paras.items(): 605 | g = grads[k] 606 | if super_paras is not None and super_grads is not None: 607 | score += v.mul(g).abs().sum() / super_paras[k].mul(super_grads[k]).abs().sum() 608 | else: 609 | score += v.mul(g).abs().sum() 610 | if num_paras: 611 | score /= num_paras 612 | return score.cpu() 613 | 614 | def criterion_l_taylor7(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # Euclidean norm of the gradients v1 615 | score = 0. 616 | for k, v in paras.items(): 617 | g = grads[k] 618 | if super_paras is not None and super_grads is not None: 619 | score += (g ** 2).sum() / (super_grads[k] ** 2).sum() 620 | else: 621 | score += (g ** 2).sum() 622 | if num_paras: 623 | score /= num_paras 624 | score = math.sqrt(score) 625 | return score 626 | 627 | def criterion_l_taylor8(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # Euclidean norm of the gradients v2 628 | score = 0. 629 | for k, v in paras.items(): 630 | g = grads[k] 631 | if super_paras is not None and super_grads is not None: 632 | score += g.norm() / super_grads[k].norm() 633 | else: 634 | score += g.norm() 635 | if num_paras: 636 | score /= num_paras 637 | return score.cpu() 638 | 639 | def criterion_l_taylor9(self, paras, grads, num_paras=None, super_paras=None, super_grads=None): # snip 640 | score = 0. 
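        # SNIP-style saliency: sum of |g| over the retained weights; the final
        # math.sqrt below is monotonic, so it does not change the ranking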
641 | for k, v in paras.items(): 642 | g = grads[k] 643 | if super_paras is not None and super_grads is not None: 644 | score += g.abs().sum() / super_grads[k].abs().sum() 645 | else: 646 | score += g.abs().sum() 647 | if num_paras: 648 | score /= num_paras 649 | score = math.sqrt(score) 650 | return score 651 | 652 | def sort_eval(self, vis_dict, top_num=50): 653 | acc_list = [] 654 | score_lists = {} 655 | for cand, info in vis_dict.items(): 656 | # only consider the model under limits 657 | if 'acc' not in info.keys(): 658 | continue 659 | acc_list.append(info['acc']) 660 | score_stats = info['score_stats'] 661 | for k, score in score_stats.items(): 662 | if k not in score_lists.keys(): 663 | score_lists[k] = [] 664 | if type(score) == float: 665 | score_lists[k].append(score) 666 | else: 667 | score_lists[k].append(score.cpu()) 668 | acc_list = np.array(acc_list) 669 | for k in score_lists.keys(): 670 | score_lists[k] = np.array(score_lists[k]) 671 | 672 | p_vals = self.get_p_value(acc_list, score_lists) 673 | p_dict = {} 674 | for (p_val, k) in p_vals: 675 | p_dict[k] = p_val 676 | 677 | idx = acc_list.argsort()[-top_num:][::-1] 678 | sorted_acc = [(acc_list[idx], 'acc_list')] 679 | for k, scores in score_lists.items(): 680 | if p_dict[k] > 0: 681 | idx = scores.argsort()[-top_num:][::-1] 682 | else: 683 | idx = scores.argsort()[:top_num] 684 | acc = acc_list[idx] 685 | acc.sort() 686 | sorted_acc.append((acc[::-1], k)) 687 | 688 | def compare(a, b): 689 | for i in range(a[0].size): 690 | if a[0][i] > b[0][i]: 691 | return 1 692 | elif a[0][i] < b[0][i]: 693 | return -1 694 | return 0 695 | 696 | sorted_acc.sort(key=functools.cmp_to_key(compare)) 697 | for acc in sorted_acc: 698 | print(acc) 699 | 700 | -------------------------------------------------------------------------------- /lib/subImageNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | random.seed(0) 5 | parser = argparse.ArgumentParser('Generate SubImageNet', add_help=False) 6 | parser.add_argument('--data-path', default='../data/imagenet', type=str, 7 | help='dataset path') 8 | args = parser.parse_args() 9 | 10 | data_path = args.data_path 11 | ImageNet_train_path = os.path.join(data_path, 'train') 12 | subImageNet_name = 'subImageNet' 13 | class_idx_txt_path = os.path.join(data_path, subImageNet_name) 14 | 15 | # train 16 | classes = sorted(os.listdir(ImageNet_train_path)) 17 | if not os.path.exists(os.path.join(data_path, subImageNet_name)): 18 | os.mkdir(os.path.join(data_path, subImageNet_name)) 19 | 20 | subImageNet = dict() 21 | with open(os.path.join(class_idx_txt_path, 'subimages_list.txt'), 'w') as f: 22 | subImageNet_class = classes 23 | for iclass in subImageNet_class: 24 | class_path = os.path.join(ImageNet_train_path, iclass) 25 | if not os.path.exists( 26 | os.path.join( 27 | data_path, 28 | subImageNet_name, 29 | iclass)): 30 | os.mkdir(os.path.join(data_path, subImageNet_name, iclass)) 31 | subImages = random.sample(sorted(os.listdir(class_path)), 100) 32 | # print("{}\n".format(subImages)) 33 | f.write("{}\n".format(subImages)) 34 | subImageNet[iclass] = subImages 35 | for image in subImages: 36 | raw_path = os.path.join(ImageNet_train_path, iclass, image) 37 | new_ipath = os.path.join( 38 | data_path, subImageNet_name, iclass, image) 39 | os.system('cp {} {}'.format(raw_path, new_ipath)) 40 | 41 | sub_classes = sorted(subImageNet.keys()) 42 | with open(os.path.join(class_idx_txt_path, 'info.txt'), 'w') as f: 43 | 
class_idx = 0 44 | for key in sub_classes: 45 | images = sorted((subImageNet[key])) 46 | # print(len(images)) 47 | f.write("{}\n".format(key)) 48 | class_idx = class_idx + 1 -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import time 4 | from collections import defaultdict, deque 5 | import datetime 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | 11 | class SmoothedValue(object): 12 | """Track a series of values and provide access to smoothed values over a 13 | window or the global series average. 14 | """ 15 | 16 | def __init__(self, window_size=20, fmt=None): 17 | if fmt is None: 18 | fmt = "{median:.4f} ({global_avg:.4f})" 19 | self.deque = deque(maxlen=window_size) 20 | self.total = 0.0 21 | self.count = 0 22 | self.fmt = fmt 23 | 24 | def update(self, value, n=1): 25 | self.deque.append(value) 26 | self.count += n 27 | self.total += value * n 28 | 29 | def synchronize_between_processes(self): 30 | """ 31 | Warning: does not synchronize the deque! 32 | """ 33 | if not is_dist_avail_and_initialized(): 34 | return 35 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 36 | dist.barrier() 37 | dist.all_reduce(t) 38 | t = t.tolist() 39 | self.count = int(t[0]) 40 | self.total = t[1] 41 | 42 | @property 43 | def median(self): 44 | d = torch.tensor(list(self.deque)) 45 | return d.median().item() 46 | 47 | @property 48 | def avg(self): 49 | d = torch.tensor(list(self.deque), dtype=torch.float32) 50 | return d.mean().item() 51 | 52 | @property 53 | def global_avg(self): 54 | return self.total / self.count 55 | 56 | @property 57 | def max(self): 58 | return max(self.deque) 59 | 60 | @property 61 | def value(self): 62 | return self.deque[-1] 63 | 64 | def __str__(self): 65 | return self.fmt.format( 66 | median=self.median, 67 | avg=self.avg, 68 | global_avg=self.global_avg, 69 | max=self.max, 70 | value=self.value) 71 | 72 | 73 | class MetricLogger(object): 74 | def __init__(self, delimiter="\t"): 75 | self.meters = defaultdict(SmoothedValue) 76 | self.delimiter = delimiter 77 | 78 | def update(self, **kwargs): 79 | for k, v in kwargs.items(): 80 | if isinstance(v, torch.Tensor): 81 | v = v.item() 82 | assert isinstance(v, (float, int)) 83 | self.meters[k].update(v) 84 | 85 | def __getattr__(self, attr): 86 | if attr in self.meters: 87 | return self.meters[attr] 88 | if attr in self.__dict__: 89 | return self.__dict__[attr] 90 | raise AttributeError("'{}' object has no attribute '{}'".format( 91 | type(self).__name__, attr)) 92 | 93 | def __str__(self): 94 | loss_str = [] 95 | for name, meter in self.meters.items(): 96 | loss_str.append( 97 | "{}: {}".format(name, str(meter)) 98 | ) 99 | return self.delimiter.join(loss_str) 100 | 101 | def synchronize_between_processes(self): 102 | for meter in self.meters.values(): 103 | meter.synchronize_between_processes() 104 | 105 | def add_meter(self, name, meter): 106 | self.meters[name] = meter 107 | 108 | def log_every(self, iterable, print_freq, header=None): 109 | i = 0 110 | if not header: 111 | header = '' 112 | start_time = time.time() 113 | end = time.time() 114 | iter_time = SmoothedValue(fmt='{avg:.4f}') 115 | data_time = SmoothedValue(fmt='{avg:.4f}') 116 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 117 | log_msg = [ 118 | header, 119 | '[{0' + space_fmt + '}/{1}]', 120 | 'eta: {eta}', 121 | '{meters}', 122 | 'time: {time}', 123 
| 'data: {data}' 124 | ] 125 | if torch.cuda.is_available(): 126 | log_msg.append('max mem: {memory:.0f}') 127 | log_msg = self.delimiter.join(log_msg) 128 | MB = 1024.0 * 1024.0 129 | for obj in iterable: 130 | data_time.update(time.time() - end) 131 | yield obj 132 | iter_time.update(time.time() - end) 133 | if i % print_freq == 0 or i == len(iterable) - 1: 134 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 135 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 136 | if torch.cuda.is_available(): 137 | print(log_msg.format( 138 | i, len(iterable), eta=eta_string, 139 | meters=str(self), 140 | time=str(iter_time), data=str(data_time), 141 | memory=torch.cuda.max_memory_allocated() / MB)) 142 | else: 143 | print(log_msg.format( 144 | i, len(iterable), eta=eta_string, 145 | meters=str(self), 146 | time=str(iter_time), data=str(data_time))) 147 | i += 1 148 | end = time.time() 149 | total_time = time.time() - start_time 150 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 151 | print('{} Total time: {} ({:.4f} s / it)'.format( 152 | header, total_time_str, total_time / len(iterable))) 153 | 154 | 155 | def _load_checkpoint_for_ema(model_ema, checkpoint): 156 | """ 157 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 158 | """ 159 | mem_file = io.BytesIO() 160 | torch.save(checkpoint, mem_file) 161 | mem_file.seek(0) 162 | model_ema._load_checkpoint(mem_file) 163 | 164 | 165 | def setup_for_distributed(is_master): 166 | """ 167 | This function disables printing when not in master process 168 | """ 169 | import builtins as __builtin__ 170 | builtin_print = __builtin__.print 171 | 172 | def print(*args, **kwargs): 173 | force = kwargs.pop('force', False) 174 | if is_master or force: 175 | builtin_print(*args, **kwargs) 176 | 177 | __builtin__.print = print 178 | 179 | 180 | def is_dist_avail_and_initialized(): 181 | if not dist.is_available(): 182 | return False 183 | if not dist.is_initialized(): 184 | return False 185 | return True 186 | 187 | 188 | def get_world_size(): 189 | if not is_dist_avail_and_initialized(): 190 | return 1 191 | return dist.get_world_size() 192 | 193 | 194 | def get_rank(): 195 | if not is_dist_avail_and_initialized(): 196 | return 0 197 | return dist.get_rank() 198 | 199 | 200 | def is_main_process(): 201 | return get_rank() == 0 202 | 203 | 204 | def save_on_master(*args, **kwargs): 205 | if is_main_process(): 206 | torch.save(*args, **kwargs) 207 | 208 | 209 | def init_distributed_mode(args): 210 | if 'OMPI_COMM_WORLD_RANK' in os.environ: 211 | args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK')) 212 | args.world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE')) 213 | args.gpu = args.rank % torch.cuda.device_count() 214 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 215 | args.rank = int(os.environ["RANK"]) 216 | args.world_size = int(os.environ['WORLD_SIZE']) 217 | args.gpu = int(os.environ['LOCAL_RANK']) 218 | elif 'SLURM_PROCID' in os.environ: 219 | args.rank = int(os.environ['SLURM_PROCID']) 220 | args.gpu = args.rank % torch.cuda.device_count() 221 | else: 222 | print('Not using distributed mode') 223 | args.distributed = False 224 | return 225 | 226 | args.distributed = True 227 | 228 | torch.cuda.set_device(args.gpu) 229 | args.dist_backend = 'nccl' 230 | print('| distributed init (rank {}): {}'.format( 231 | args.rank, args.dist_url), flush=True) 232 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 233 | 
world_size=args.world_size, rank=args.rank) 234 | torch.distributed.barrier() 235 | setup_for_distributed(args.rank == 0) 236 | -------------------------------------------------------------------------------- /model/module/Linear_super.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class LinearSuper(nn.Linear): 7 | def __init__(self, super_in_dim, super_out_dim, bias=True, uniform_=None, non_linear='linear', scale=False): 8 | super().__init__(super_in_dim, super_out_dim, bias=bias) 9 | 10 | # super_in_dim and super_out_dim indicate the largest network! 11 | self.super_in_dim = super_in_dim 12 | self.super_out_dim = super_out_dim 13 | 14 | # input_dim and output_dim indicate the current sampled size 15 | self.sample_in_dim = None 16 | self.sample_out_dim = None 17 | 18 | self.samples = {} 19 | 20 | self.scale = scale 21 | self._reset_parameters(bias, uniform_, non_linear) 22 | self.profiling = False 23 | 24 | def profile(self, mode=True): 25 | self.profiling = mode 26 | 27 | def sample_parameters(self, resample=False): 28 | if self.profiling or resample: 29 | return self._sample_parameters() 30 | return self.samples 31 | 32 | def _reset_parameters(self, bias, uniform_, non_linear): 33 | nn.init.xavier_uniform_(self.weight) if uniform_ is None else uniform_( 34 | self.weight, non_linear=non_linear) 35 | if bias: 36 | nn.init.constant_(self.bias, 0.) 37 | 38 | def set_sample_config(self, sample_in_dim, sample_out_dim): 39 | self.sample_in_dim = sample_in_dim 40 | self.sample_out_dim = sample_out_dim 41 | 42 | self._sample_parameters() 43 | 44 | def _sample_parameters(self): 45 | self.samples['weight'] = sample_weight(self.weight, self.sample_in_dim, self.sample_out_dim) 46 | self.samples['bias'] = self.bias 47 | self.sample_scale = self.super_out_dim/self.sample_out_dim 48 | if self.bias is not None: 49 | self.samples['bias'] = sample_bias(self.bias, self.sample_out_dim) 50 | return self.samples 51 | 52 | def forward(self, x): 53 | self.sample_parameters() 54 | return F.linear(x, self.samples['weight'], self.samples['bias']) * (self.sample_scale if self.scale else 1) 55 | 56 | def calc_sampled_param_num(self): 57 | assert 'weight' in self.samples.keys() 58 | weight_numel = self.samples['weight'].numel() 59 | 60 | if self.samples['bias'] is not None: 61 | bias_numel = self.samples['bias'].numel() 62 | else: 63 | bias_numel = 0 64 | 65 | return weight_numel + bias_numel 66 | def get_complexity(self, sequence_length): 67 | total_flops = 0 68 | total_flops += sequence_length * np.prod(self.samples['weight'].size()) 69 | return total_flops 70 | 71 | def sample_weight(weight, sample_in_dim, sample_out_dim): 72 | sample_weight = weight[:, :sample_in_dim] 73 | sample_weight = sample_weight[:sample_out_dim, :] 74 | 75 | return sample_weight 76 | 77 | 78 | def sample_bias(bias, sample_out_dim): 79 | sample_bias = bias[:sample_out_dim] 80 | 81 | return sample_bias 82 | -------------------------------------------------------------------------------- /model/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinyvision/PreNAS/0050c7a22482e8736f148bc41ab0d952968a8748/model/module/__init__.py -------------------------------------------------------------------------------- /model/module/embedding_super.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.utils import to_2tuple 5 | import numpy as np 6 | 7 | class PatchembedSuper(nn.Module): 8 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, scale=False): 9 | super(PatchembedSuper, self).__init__() 10 | 11 | img_size = to_2tuple(img_size) 12 | patch_size = to_2tuple(patch_size) 13 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 14 | self.img_size = img_size 15 | self.patch_size = patch_size 16 | self.num_patches = num_patches 17 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 18 | self.super_embed_dim = embed_dim 19 | self.scale = scale 20 | 21 | # sampled_ 22 | self.sample_embed_dim = None 23 | self.sampled_weight = None 24 | self.sampled_bias = None 25 | self.sampled_scale = None 26 | 27 | def set_sample_config(self, sample_embed_dim): 28 | self.sample_embed_dim = sample_embed_dim 29 | self.sampled_weight = self.proj.weight[:sample_embed_dim, ...] 30 | self.sampled_bias = self.proj.bias[:self.sample_embed_dim, ...] 31 | if self.scale: 32 | self.sampled_scale = self.super_embed_dim / sample_embed_dim 33 | def forward(self, x): 34 | B, C, H, W = x.shape 35 | assert H == self.img_size[0] and W == self.img_size[1], \ 36 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 37 | x = F.conv2d(x, self.sampled_weight, self.sampled_bias, stride=self.patch_size, padding=self.proj.padding, dilation=self.proj.dilation).flatten(2).transpose(1,2) 38 | if self.scale: 39 | return x * self.sampled_scale 40 | return x 41 | def calc_sampled_param_num(self): 42 | return self.sampled_weight.numel() + self.sampled_bias.numel() 43 | 44 | def get_complexity(self, sequence_length): 45 | total_flops = 0 46 | if self.sampled_bias is not None: 47 | total_flops += self.sampled_bias.size(0) 48 | total_flops += sequence_length * np.prod(self.sampled_weight.size()) 49 | return total_flops -------------------------------------------------------------------------------- /model/module/layernorm_super.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SwitchableLayerNormSuper(nn.Module): 7 | def __init__(self, embed_dim_list): 8 | super(SwitchableLayerNormSuper, self).__init__() 9 | 10 | self.embed_dim_list = embed_dim_list 11 | 12 | # the largest embed dim 13 | self.super_embed_dim = max(embed_dim_list) 14 | 15 | # the current sampled embed dim 16 | self.sample_embed_dim = None 17 | 18 | self.lns = nn.ModuleList([nn.LayerNorm(dim) for dim in embed_dim_list]) 19 | 20 | def set_sample_config(self, sample_embed_dim): 21 | self.sample_embed_dim = sample_embed_dim 22 | self.sample_idx = self.embed_dim_list.index(sample_embed_dim) 23 | 24 | def forward(self, x): 25 | return self.lns[self.sample_idx](x) 26 | 27 | def calc_sampled_param_num(self): 28 | ln = self.lns[self.sample_idx] 29 | return ln.weight.numel() + ln.bias.numel() 30 | 31 | def get_complexity(self, sequence_length): 32 | return sequence_length * self.sample_embed_dim 33 | 34 | 35 | class LayerNormSuper(torch.nn.LayerNorm): 36 | def __init__(self, super_embed_dim): 37 | super().__init__(super_embed_dim) 38 | 39 | # the largest embed dim 40 | self.super_embed_dim = super_embed_dim 41 | 42 | # the current sampled embed dim 43 | 
self.sample_embed_dim = None 44 | 45 | self.samples = {} 46 | self.profiling = False 47 | 48 | def profile(self, mode=True): 49 | self.profiling = mode 50 | 51 | def sample_parameters(self, resample=False): 52 | if self.profiling or resample: 53 | return self._sample_parameters() 54 | return self.samples 55 | 56 | def _sample_parameters(self): 57 | self.samples['weight'] = self.weight[:self.sample_embed_dim] 58 | self.samples['bias'] = self.bias[:self.sample_embed_dim] 59 | return self.samples 60 | 61 | def set_sample_config(self, sample_embed_dim): 62 | self.sample_embed_dim = sample_embed_dim 63 | self._sample_parameters() 64 | 65 | def forward(self, x): 66 | self.sample_parameters() 67 | return F.layer_norm(x, (self.sample_embed_dim,), weight=self.samples['weight'], bias=self.samples['bias'], eps=self.eps) 68 | 69 | def calc_sampled_param_num(self): 70 | assert 'weight' in self.samples.keys() 71 | assert 'bias' in self.samples.keys() 72 | return self.samples['weight'].numel() + self.samples['bias'].numel() 73 | 74 | def get_complexity(self, sequence_length): 75 | return sequence_length * self.sample_embed_dim 76 | -------------------------------------------------------------------------------- /model/module/multihead_super.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | from .Linear_super import LinearSuper 6 | from .qkv_super import qkv_super 7 | from .scaling_super import ScalingSuper 8 | from ..utils import trunc_normal_ 9 | from torch.cuda.amp import autocast 10 | 11 | def softmax(x, dim, onnx_trace=False): 12 | if onnx_trace: 13 | return F.softmax(x.float(), dim=dim) 14 | else: 15 | return F.softmax(x, dim=dim, dtype=torch.float32) 16 | 17 | class RelativePosition2D_super(nn.Module): 18 | 19 | def __init__(self, num_units, max_relative_position): 20 | super().__init__() 21 | 22 | self.num_units = num_units 23 | self.max_relative_position = max_relative_position 24 | # The first element in embeddings_table_v is the vertical embedding for the class 25 | self.embeddings_table_v = nn.Parameter(torch.randn(max_relative_position * 2 + 2, num_units)) 26 | self.embeddings_table_h = nn.Parameter(torch.randn(max_relative_position * 2 + 2, num_units)) 27 | 28 | trunc_normal_(self.embeddings_table_v, std=.02) 29 | trunc_normal_(self.embeddings_table_h, std=.02) 30 | 31 | self.sample_head_dim = None 32 | self.sample_embeddings_table_h = None 33 | self.sample_embeddings_table_v = None 34 | 35 | def set_sample_config(self, sample_head_dim): 36 | self.sample_head_dim = sample_head_dim 37 | self.sample_embeddings_table_h = self.embeddings_table_h[:,:sample_head_dim] 38 | self.sample_embeddings_table_v = self.embeddings_table_v[:,:sample_head_dim] 39 | 40 | def calc_sampled_param_num(self): 41 | return self.sample_embeddings_table_h.numel() + self.sample_embeddings_table_v.numel() 42 | 43 | def forward(self, length_q, length_k): 44 | # remove the first cls token distance computation 45 | length_q = length_q - 1 46 | length_k = length_k - 1 47 | device = self.embeddings_table_v.device 48 | range_vec_q = torch.arange(length_q, device=device) 49 | range_vec_k = torch.arange(length_k, device=device) 50 | # compute the row and column distance 51 | distance_mat_v = (range_vec_k[None, :] // int(length_q ** 0.5 ) - range_vec_q[:, None] // int(length_q ** 0.5 )) 52 | distance_mat_h = (range_vec_k[None, :] % int(length_q ** 0.5 ) - range_vec_q[:, None] % 
int(length_q ** 0.5 )) 53 | # clip the distance to the range of [-max_relative_position, max_relative_position] 54 | distance_mat_clipped_v = torch.clamp(distance_mat_v, -self.max_relative_position, self.max_relative_position) 55 | distance_mat_clipped_h = torch.clamp(distance_mat_h, -self.max_relative_position, self.max_relative_position) 56 | 57 | # translate the distance from [1, 2 * max_relative_position + 1], 0 is for the cls token 58 | final_mat_v = distance_mat_clipped_v + self.max_relative_position + 1 59 | final_mat_h = distance_mat_clipped_h + self.max_relative_position + 1 60 | # pad the 0 which represent the cls token 61 | final_mat_v = torch.nn.functional.pad(final_mat_v, (1,0,1,0), "constant", 0) 62 | final_mat_h = torch.nn.functional.pad(final_mat_h, (1,0,1,0), "constant", 0) 63 | 64 | final_mat_v = final_mat_v.long() 65 | final_mat_h = final_mat_h.long() 66 | # get the embeddings with the corresponding distance 67 | embeddings = self.sample_embeddings_table_v[final_mat_v] + self.sample_embeddings_table_h[final_mat_h] 68 | 69 | return embeddings 70 | 71 | class AttentionSuper(nn.Module): 72 | def __init__(self, super_embed_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., normalization = False, relative_position = False, 73 | num_patches = None, max_relative_position=14, scale=False, change_qkv = False, choices=None, scale_attn=False): 74 | super().__init__() 75 | self.num_heads = num_heads 76 | head_dim = super_embed_dim // num_heads 77 | self.scale = qk_scale or head_dim ** -0.5 78 | self.super_embed_dim = super_embed_dim 79 | 80 | self.fc_scale = scale 81 | self.change_qkv = change_qkv 82 | self.scale_attn = scale_attn 83 | 84 | self.choices = choices 85 | 86 | if change_qkv: 87 | self.qkv = qkv_super(super_embed_dim, 3 * super_embed_dim, bias=qkv_bias) 88 | if scale_attn: 89 | self.qk_scaling = ScalingSuper([n_head * 64 for n_head in self.choices['num_heads']]) 90 | self.v_scaling = ScalingSuper([n_head * 64 for n_head in self.choices['num_heads']]) 91 | else: 92 | self.qkv = LinearSuper(super_embed_dim, 3 * super_embed_dim, bias=qkv_bias) 93 | 94 | self.relative_position = relative_position 95 | if self.relative_position: 96 | self.rel_pos_embed_k = RelativePosition2D_super(super_embed_dim //num_heads, max_relative_position) 97 | self.rel_pos_embed_v = RelativePosition2D_super(super_embed_dim //num_heads, max_relative_position) 98 | self.max_relative_position = max_relative_position 99 | self.sample_qk_embed_dim = None 100 | self.sample_v_embed_dim = None 101 | self.sample_num_heads = None 102 | self.sample_scale = None 103 | self.sample_in_embed_dim = None 104 | 105 | self.proj = LinearSuper(super_embed_dim, super_embed_dim) 106 | 107 | self.attn_drop = nn.Dropout(attn_drop) 108 | self.proj_drop = nn.Dropout(proj_drop) 109 | 110 | def set_sample_config(self, sample_q_embed_dim=None, sample_num_heads=None, sample_in_embed_dim=None): 111 | 112 | self.sample_in_embed_dim = sample_in_embed_dim 113 | self.sample_num_heads = sample_num_heads 114 | if not self.change_qkv: 115 | self.sample_qk_embed_dim = self.super_embed_dim 116 | self.sample_scale = (sample_in_embed_dim // self.sample_num_heads) ** -0.5 117 | else: 118 | if self.scale_attn: 119 | self.qk_scaling.set_sample_config(sample_q_embed_dim) 120 | self.v_scaling.set_sample_config(sample_q_embed_dim) 121 | self.sample_qk_embed_dim = sample_q_embed_dim 122 | self.sample_scale = (self.sample_qk_embed_dim // self.sample_num_heads) ** -0.5 123 | 124 | 
self.qkv.set_sample_config(sample_in_dim=sample_in_embed_dim, sample_out_dim=3*self.sample_qk_embed_dim) 125 | self.proj.set_sample_config(sample_in_dim=self.sample_qk_embed_dim, sample_out_dim=sample_in_embed_dim) 126 | if self.relative_position: 127 | self.rel_pos_embed_k.set_sample_config(self.sample_qk_embed_dim // sample_num_heads) 128 | self.rel_pos_embed_v.set_sample_config(self.sample_qk_embed_dim // sample_num_heads) 129 | def calc_sampled_param_num(self): 130 | 131 | return 0 132 | def get_complexity(self, sequence_length): 133 | total_flops = 0 134 | total_flops += self.qkv.get_complexity(sequence_length) 135 | # attn 136 | total_flops += sequence_length * sequence_length * self.sample_qk_embed_dim 137 | # x 138 | total_flops += sequence_length * sequence_length * self.sample_qk_embed_dim 139 | total_flops += self.proj.get_complexity(sequence_length) 140 | if self.relative_position: 141 | total_flops += self.max_relative_position * sequence_length * sequence_length + sequence_length * sequence_length / 2.0 142 | total_flops += self.max_relative_position * sequence_length * sequence_length + sequence_length * self.sample_qk_embed_dim / 2.0 143 | return total_flops 144 | 145 | def forward(self, x): 146 | B, N, C = x.shape 147 | qkv = self.qkv(x).reshape(B, N, 3, self.sample_num_heads, -1).permute(2, 0, 3, 1, 4) 148 | with autocast(enabled=False): 149 | q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float() # make torchscript happy (cannot use tensor as tuple) 150 | 151 | if self.scale_attn: 152 | q = self.qk_scaling(q.transpose(1, 2).reshape(B, N, -1)).reshape(B, N, self.sample_num_heads, -1).transpose(1, 2) 153 | attn = (q @ k.transpose(-2, -1)) * self.sample_scale 154 | if self.relative_position: 155 | r_p_k = self.rel_pos_embed_k(N, N) 156 | attn = attn + (q.permute(2, 0, 1, 3).reshape(N, self.sample_num_heads * B, -1) @ r_p_k.transpose(2, 1)) \ 157 | .transpose(1, 0).reshape(B, self.sample_num_heads, N, N) * self.sample_scale 158 | 159 | attn = attn.softmax(dim=-1) 160 | attn = self.attn_drop(attn) 161 | 162 | if self.scale_attn: 163 | v = self.v_scaling(v.transpose(1, 2).reshape(B, N, -1)).reshape(B, N, self.sample_num_heads, -1).transpose(1, 2) 164 | x = (attn @ v).transpose(1,2).reshape(B, N, -1) 165 | if self.relative_position: 166 | r_p_v = self.rel_pos_embed_v(N, N) 167 | attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * self.sample_num_heads, -1) 168 | # The size of attention is (B, num_heads, N, N), reshape it to (N, B*num_heads, N) and do batch matmul with 169 | # the relative position embedding of V (N, N, head_dim) get shape like (N, B*num_heads, head_dim). 
We reshape it to the 170 | # same size as x (B, num_heads, N, hidden_dim) 171 | x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B, self.sample_num_heads, N, -1).transpose(2,1).reshape(B, N, -1) 172 | 173 | if self.fc_scale: 174 | x = x * (self.super_embed_dim / self.sample_qk_embed_dim) 175 | x = self.proj(x) 176 | x = self.proj_drop(x) 177 | return x 178 | -------------------------------------------------------------------------------- /model/module/qkv_super.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class qkv_super(nn.Linear): 8 | def __init__(self, super_in_dim, super_out_dim, bias=True, uniform_=None, non_linear='linear', scale=False): 9 | super().__init__(super_in_dim, super_out_dim, bias=bias) 10 | 11 | # super_in_dim and super_out_dim indicate the largest network! 12 | self.super_in_dim = super_in_dim 13 | self.super_out_dim = super_out_dim 14 | 15 | # input_dim and output_dim indicate the current sampled size 16 | self.sample_in_dim = None 17 | self.sample_out_dim = None 18 | 19 | self.samples = {} 20 | 21 | self.scale = scale 22 | # self._reset_parameters(bias, uniform_, non_linear) 23 | self.profiling = False 24 | 25 | def profile(self, mode=True): 26 | self.profiling = mode 27 | 28 | def sample_parameters(self, resample=False): 29 | if self.profiling or resample: 30 | return self._sample_parameters() 31 | return self.samples 32 | 33 | def _reset_parameters(self, bias, uniform_, non_linear): 34 | nn.init.xavier_uniform_(self.weight) if uniform_ is None else uniform_( 35 | self.weight, non_linear=non_linear) 36 | if bias: 37 | nn.init.constant_(self.bias, 0.) 38 | 39 | def set_sample_config(self, sample_in_dim, sample_out_dim): 40 | self.sample_in_dim = sample_in_dim 41 | self.sample_out_dim = sample_out_dim 42 | 43 | self._sample_parameters() 44 | 45 | def _sample_parameters(self): 46 | self.samples['weight'] = sample_weight(self.weight, self.sample_in_dim, self.sample_out_dim) 47 | self.samples['bias'] = self.bias 48 | self.sample_scale = self.super_out_dim/self.sample_out_dim 49 | if self.bias is not None: 50 | self.samples['bias'] = sample_bias(self.bias, self.sample_out_dim) 51 | return self.samples 52 | 53 | def forward(self, x): 54 | self.sample_parameters() 55 | return F.linear(x, self.samples['weight'], self.samples['bias']) * (self.sample_scale if self.scale else 1) 56 | 57 | def calc_sampled_param_num(self): 58 | assert 'weight' in self.samples.keys() 59 | weight_numel = self.samples['weight'].numel() 60 | 61 | if self.samples['bias'] is not None: 62 | bias_numel = self.samples['bias'].numel() 63 | else: 64 | bias_numel = 0 65 | 66 | return weight_numel + bias_numel 67 | def get_complexity(self, sequence_length): 68 | total_flops = 0 69 | total_flops += sequence_length * np.prod(self.samples['weight'].size()) 70 | return total_flops 71 | 72 | def sample_weight(weight, sample_in_dim, sample_out_dim): 73 | 74 | sample_weight = weight[:, :sample_in_dim] 75 | sample_weight = torch.cat([sample_weight[i:sample_out_dim:3, :] for i in range(3)], dim =0) 76 | 77 | return sample_weight 78 | 79 | 80 | def sample_bias(bias, sample_out_dim): 81 | #sample_bias = bias[:sample_out_dim] 82 | sample_bias = torch.cat([bias[i:sample_out_dim:3] for i in range(3)]) 83 | 84 | return sample_bias 85 | -------------------------------------------------------------------------------- /model/module/scaling_super.py: 
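ScalingSuper in this file keeps one learnable per-channel scaling vector for every width in the search space and, at sample time, selects the row belonging to the sampled embed_dim. A minimal sketch of that lookup, with hypothetical width choices:

import torch

# Hypothetical widths; the real list comes from the search-space YAML.
embed_dim_list = [192, 216, 240]
scalings = torch.nn.Parameter(
    1e-4 * torch.ones(len(embed_dim_list), max(embed_dim_list)))

sample_embed_dim = 216                        # width sampled for this pass
idx = embed_dim_list.index(sample_embed_dim)  # row owned by this width
x = torch.randn(4, 197, sample_embed_dim)     # (batch, tokens, width)
y = x * scalings[idx][:sample_embed_dim]      # truncate the row to the width

Note the 1e-4 initialization: each width's scaling vector starts near zero and is learned per width during training.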
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ScalingSuper(nn.Module): 7 | def __init__(self, embed_dim_list): 8 | super(ScalingSuper, self).__init__() 9 | 10 | self.embed_dim_list = embed_dim_list 11 | 12 | # the largest embed dim 13 | self.super_embed_dim = max(embed_dim_list) 14 | 15 | # the current sampled embed dim 16 | self.sample_embed_dim = None 17 | 18 | self.scalings = nn.Parameter(1e-4 * torch.ones(len(embed_dim_list), self.super_embed_dim)) 19 | 20 | def set_sample_config(self, sample_embed_dim): 21 | self.sample_embed_dim = sample_embed_dim 22 | self.sample_idx = self.embed_dim_list.index(sample_embed_dim) 23 | 24 | def forward(self, x): 25 | return x * self.scalings[self.sample_idx][:self.sample_embed_dim] 26 | 27 | def calc_sampled_param_num(self): 28 | return 0 #self.scalings[self.sample_idx][:self.sample_embed_dim].numel() 29 | 30 | def get_complexity(self, sequence_length): 31 | return 0 #sequence_length * self.sample_embed_dim 32 | 33 | -------------------------------------------------------------------------------- /model/supernet_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from model.module.Linear_super import LinearSuper 7 | from model.module.layernorm_super import LayerNormSuper, SwitchableLayerNormSuper 8 | from model.module.scaling_super import ScalingSuper 9 | from model.module.multihead_super import AttentionSuper 10 | from model.module.embedding_super import PatchembedSuper 11 | from model.utils import trunc_normal_ 12 | from model.utils import DropPath 13 | import numpy as np 14 | 15 | def gelu(x: torch.Tensor) -> torch.Tensor: 16 | if hasattr(torch.nn.functional, 'gelu'): 17 | return torch.nn.functional.gelu(x.float()).type_as(x) 18 | else: 19 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 20 | 21 | 22 | class Vision_TransformerSuper(nn.Module): 23 | 24 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 25 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 26 | drop_path_rate=0., pre_norm=True, scale=False, gp=False, relative_position=False, change_qkv=False, abs_pos = True, max_relative_position=14, 27 | choices=None, switch_ln=False, scale_attn=False, scale_mlp=False, scale_embed=False): 28 | super(Vision_TransformerSuper, self).__init__() 29 | # the configs of super arch 30 | self.super_embed_dim = embed_dim 31 | # self.super_embed_dim = args.embed_dim 32 | self.super_mlp_ratio = mlp_ratio 33 | self.super_layer_num = depth 34 | self.super_num_heads = num_heads 35 | self.super_dropout = drop_rate 36 | self.super_attn_dropout = attn_drop_rate 37 | self.num_classes = num_classes 38 | self.pre_norm=pre_norm 39 | self.scale=scale 40 | self.patch_embed_super = PatchembedSuper(img_size=img_size, patch_size=patch_size, 41 | in_chans=in_chans, embed_dim=embed_dim) 42 | self.gp = gp 43 | self.choices = choices 44 | self.scale_embed = scale_embed 45 | 46 | # configs for the sampled subTransformer 47 | self.sample_embed_dim = None 48 | self.sample_mlp_ratio = None 49 | self.sample_layer_num = None 50 | self.sample_num_heads = None 51 | self.sample_dropout = None 52 | self.sample_output_dim = None 53 | 54 | self.blocks = nn.ModuleList() 55 | dpr = [x.item() for x in 
torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 56 | 57 | for i in range(depth): 58 | self.blocks.append(TransformerEncoderLayer(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, 59 | qkv_bias=qkv_bias, qk_scale=qk_scale, dropout=drop_rate, 60 | attn_drop=attn_drop_rate, drop_path=dpr[i], 61 | pre_norm=pre_norm, scale=self.scale, 62 | change_qkv=change_qkv, relative_position=relative_position, 63 | max_relative_position=max_relative_position, 64 | choices=choices, switch_ln=switch_ln, 65 | scale_attn=scale_attn, scale_mlp=scale_mlp, scale_embed=scale_embed)) 66 | 67 | # parameters for vision transformer 68 | num_patches = self.patch_embed_super.num_patches 69 | 70 | self.abs_pos = abs_pos 71 | if self.abs_pos: 72 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 73 | trunc_normal_(self.pos_embed, std=.02) 74 | 75 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 76 | trunc_normal_(self.cls_token, std=.02) 77 | 78 | if scale_embed: 79 | self.embed_scaling = ScalingSuper(self.choices['embed_dim']) 80 | 81 | # self.pos_drop = nn.Dropout(p=drop_rate) 82 | if self.pre_norm: 83 | if switch_ln: 84 | self.norm = SwitchableLayerNormSuper(self.choices['embed_dim']) 85 | else: 86 | self.norm = LayerNormSuper(super_embed_dim=embed_dim) 87 | 88 | # classifier head 89 | self.head = LinearSuper(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 90 | 91 | self.apply(self._init_weights) 92 | 93 | def _init_weights(self, m): 94 | if isinstance(m, nn.Linear): 95 | trunc_normal_(m.weight, std=.02) 96 | if isinstance(m, nn.Linear) and m.bias is not None: 97 | nn.init.constant_(m.bias, 0) 98 | elif isinstance(m, nn.LayerNorm): 99 | nn.init.constant_(m.bias, 0) 100 | nn.init.constant_(m.weight, 1.0) 101 | 102 | @torch.jit.ignore 103 | def no_weight_decay(self): 104 | return {'pos_embed', 'cls_token', 'rel_pos_embed'} 105 | 106 | def get_classifier(self): 107 | return self.head 108 | 109 | #def reset_classifier(self, num_classes, global_pool=''): 110 | # self.num_classes = num_classes 111 | # self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 112 | 113 | def set_sample_config(self, config: dict): 114 | self.sample_embed_dim = config['embed_dim'] 115 | self.sample_mlp_ratio = config['mlp_ratio'] 116 | self.sample_layer_num = config['layer_num'] 117 | self.sample_num_heads = config['num_heads'] 118 | self.sample_dropout = calc_dropout(self.super_dropout, self.sample_embed_dim[0], self.super_embed_dim) 119 | self.patch_embed_super.set_sample_config(self.sample_embed_dim[0]) 120 | if self.scale_embed: 121 | self.embed_scaling.set_sample_config(self.sample_embed_dim[0]) 122 | self.sample_output_dim = [out_dim for out_dim in self.sample_embed_dim[1:]] + [self.sample_embed_dim[-1]] 123 | for i, blocks in enumerate(self.blocks): 124 | # not exceed sample layer number 125 | if i < self.sample_layer_num: 126 | sample_dropout = calc_dropout(self.super_dropout, self.sample_embed_dim[i], self.super_embed_dim) 127 | sample_attn_dropout = calc_dropout(self.super_attn_dropout, self.sample_embed_dim[i], self.super_embed_dim) 128 | blocks.set_sample_config(is_identity_layer=False, 129 | sample_embed_dim=self.sample_embed_dim[i], 130 | sample_mlp_ratio=self.sample_mlp_ratio[i], 131 | sample_num_heads=self.sample_num_heads[i], 132 | sample_dropout=sample_dropout, 133 | sample_out_dim=self.sample_output_dim[i], 134 | sample_attn_dropout=sample_attn_dropout) 135 | # exceeds sample layer number 136 | else: 137 | 
blocks.set_sample_config(is_identity_layer=True) 138 | if self.pre_norm: 139 | self.norm.set_sample_config(self.sample_embed_dim[-1]) 140 | self.head.set_sample_config(self.sample_embed_dim[-1], self.num_classes) 141 | 142 | def get_sampled_params_numel(self, config): 143 | self.set_sample_config(config) 144 | numels = [] 145 | for name, module in self.named_modules(): 146 | if hasattr(module, 'calc_sampled_param_num'): 147 | if name.split('.')[0] == 'blocks' and int(name.split('.')[1]) >= config['layer_num']: 148 | continue 149 | numels.append(module.calc_sampled_param_num()) 150 | 151 | return sum(numels) + self.sample_embed_dim[0]* (2 +self.patch_embed_super.num_patches) 152 | def get_complexity(self, sequence_length): 153 | total_flops = 0 154 | total_flops += self.patch_embed_super.get_complexity(sequence_length) 155 | total_flops += np.prod(self.pos_embed[..., :self.sample_embed_dim[0]].size()) / 2.0 156 | for blk in self.blocks: 157 | total_flops += blk.get_complexity(sequence_length+1) 158 | total_flops += self.head.get_complexity(sequence_length+1) 159 | return total_flops 160 | def forward_features(self, x): 161 | B = x.shape[0] 162 | x = self.patch_embed_super(x) 163 | cls_tokens = self.cls_token[..., :self.sample_embed_dim[0]].expand(B, -1, -1) 164 | x = torch.cat((cls_tokens, x), dim=1) 165 | if self.abs_pos: 166 | x = x + self.pos_embed[..., :self.sample_embed_dim[0]] 167 | 168 | if self.scale_embed: 169 | x = self.embed_scaling(x) 170 | 171 | x = F.dropout(x, p=self.sample_dropout, training=self.training) 172 | 173 | # start_time = time.time() 174 | for blk in self.blocks: 175 | x = blk(x) 176 | # print(time.time()-start_time) 177 | if self.pre_norm: 178 | x = self.norm(x) 179 | 180 | if self.gp: 181 | return torch.mean(x[:, 1:] , dim=1) 182 | 183 | return x[:, 0] 184 | 185 | def forward(self, x): 186 | x = self.forward_features(x) 187 | x = self.head(x) 188 | return x 189 | 190 | 191 | class TransformerEncoderLayer(nn.Module): 192 | """Encoder layer block. 193 | 194 | Applies multi-head self-attention and a feed-forward network with residual 195 | connections; every sub-layer supports elastic width/head/ratio sampling. 196 | """ 197 | 198 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, dropout=0., attn_drop=0., 199 | drop_path=0., act_layer=nn.GELU, pre_norm=True, scale=False, 200 | relative_position=False, change_qkv=False, max_relative_position=14, choices=None, switch_ln=False, 201 | scale_attn=False, scale_mlp=False, scale_embed=False): 202 | super().__init__() 203 | 204 | # the configs of super arch of the encoder, three dimensions [embed_dim, mlp_ratio, num_heads] 205 | self.super_embed_dim = dim 206 | self.super_mlp_ratio = mlp_ratio 207 | self.super_ffn_embed_dim_this_layer = int(mlp_ratio * dim) 208 | self.super_num_heads = num_heads 209 | self.normalize_before = pre_norm 210 | self.super_dropout = attn_drop 211 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 212 | self.scale = scale 213 | self.relative_position = relative_position 214 | # self.super_activation_dropout = getattr(args, 'activation_dropout', 0) 215 | self.choices = choices 216 | self.scale_mlp = scale_mlp 217 | self.scale_embed = scale_embed 218 | 219 | # the configs of current sampled arch 220 | self.sample_embed_dim = None 221 | self.sample_mlp_ratio = None 222 | self.sample_ffn_embed_dim_this_layer = None 223 | self.sample_num_heads_this_layer = None 224 | self.sample_scale = None 225 | self.sample_dropout = None 226 | self.sample_attn_dropout = None 227 | 228 | self.is_identity_layer = None 229 | self.attn = AttentionSuper( 230 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, 231 | proj_drop=dropout, scale=self.scale, relative_position=self.relative_position, change_qkv=change_qkv, 232 | max_relative_position=max_relative_position, choices=choices, scale_attn=scale_attn, 233 | ) 234 | 235 | if switch_ln: 236 | self.attn_layer_norm = SwitchableLayerNormSuper(self.choices['embed_dim']) 237 | self.ffn_layer_norm = SwitchableLayerNormSuper(self.choices['embed_dim']) 238 | else: 239 | self.attn_layer_norm = LayerNormSuper(self.super_embed_dim) 240 | self.ffn_layer_norm = LayerNormSuper(self.super_embed_dim) 241 | 242 | if scale_embed: 243 | self.embed_scaling1 = ScalingSuper(self.choices['embed_dim']) 244 | self.embed_scaling2 = ScalingSuper(self.choices['embed_dim']) 245 | 246 | # self.dropout = dropout 247 | self.activation_fn = gelu 248 | # self.normalize_before = args.encoder_normalize_before 249 | 250 | self.fc1 = LinearSuper(super_in_dim=self.super_embed_dim, super_out_dim=self.super_ffn_embed_dim_this_layer) 251 | if scale_mlp: 252 | self.mlp_scaling = ScalingSuper([emb * ratio for emb in self.choices['embed_dim'] for ratio in self.choices['mlp_ratio']]) 253 | self.fc2 = LinearSuper(super_in_dim=self.super_ffn_embed_dim_this_layer, super_out_dim=self.super_embed_dim) 254 | 255 | 256 | def set_sample_config(self, is_identity_layer, sample_embed_dim=None, sample_mlp_ratio=None, sample_num_heads=None, sample_dropout=None, sample_attn_dropout=None, sample_out_dim=None): 257 | 258 | if is_identity_layer: 259 | self.is_identity_layer = True 260 | return 261 | 262 | self.is_identity_layer = False 263 | 264 | self.sample_embed_dim = sample_embed_dim 265 | self.sample_out_dim = sample_out_dim 266 | self.sample_mlp_ratio = sample_mlp_ratio 267 | self.sample_ffn_embed_dim_this_layer = int(sample_embed_dim*sample_mlp_ratio) 268 | self.sample_num_heads_this_layer = sample_num_heads 269 | 270 | self.sample_dropout = sample_dropout 271 | self.sample_attn_dropout = sample_attn_dropout 272 | self.attn_layer_norm.set_sample_config(sample_embed_dim=self.sample_embed_dim) 273 | 274 | self.attn.set_sample_config(sample_q_embed_dim=self.sample_num_heads_this_layer*64, sample_num_heads=self.sample_num_heads_this_layer, sample_in_embed_dim=self.sample_embed_dim) 275 | 276 | self.fc1.set_sample_config(sample_in_dim=self.sample_embed_dim, sample_out_dim=self.sample_ffn_embed_dim_this_layer) 277 | if self.scale_mlp: 278 | self.mlp_scaling.set_sample_config(self.sample_ffn_embed_dim_this_layer) 279 | self.fc2.set_sample_config(sample_in_dim=self.sample_ffn_embed_dim_this_layer, sample_out_dim=self.sample_out_dim) 280 | 281 | self.ffn_layer_norm.set_sample_config(sample_embed_dim=self.sample_embed_dim) 282 | 283 | if self.scale_embed: 284 | self.embed_scaling1.set_sample_config(sample_embed_dim=self.sample_embed_dim) 285 | 
self.embed_scaling2.set_sample_config(sample_embed_dim=self.sample_embed_dim) 286 | 287 | 288 | def forward(self, x): 289 | """ 290 | Args: 291 | x (Tensor): input to the layer of shape `(batch, patch_num , sample_embed_dim)` 292 | 293 | Returns: 294 | encoded output of shape `(batch, patch_num, sample_embed_dim)` 295 | """ 296 | if self.is_identity_layer: 297 | return x 298 | 299 | # compute attn 300 | # start_time = time.time() 301 | 302 | residual = x 303 | x = self.maybe_layer_norm(self.attn_layer_norm, x, before=True) 304 | x = self.attn(x) 305 | if self.scale_embed: 306 | x = self.embed_scaling1(x) 307 | x = F.dropout(x, p=self.sample_attn_dropout, training=self.training) 308 | x = self.drop_path(x) 309 | x = residual + x 310 | x = self.maybe_layer_norm(self.attn_layer_norm, x, after=True) 311 | # print("attn :", time.time() - start_time) 312 | # compute the ffn 313 | # start_time = time.time() 314 | residual = x 315 | x = self.maybe_layer_norm(self.ffn_layer_norm, x, before=True) 316 | x = self.fc1(x) 317 | if self.scale_mlp: 318 | x = self.mlp_scaling(x) 319 | x = self.activation_fn(x) 320 | x = F.dropout(x, p=self.sample_dropout, training=self.training) 321 | x = self.fc2(x) 322 | if self.scale_embed: 323 | x = self.embed_scaling2(x) 324 | x = F.dropout(x, p=self.sample_dropout, training=self.training) 325 | if self.scale: 326 | x = x * (self.super_mlp_ratio / self.sample_mlp_ratio) 327 | x = self.drop_path(x) 328 | x = residual + x 329 | x = self.maybe_layer_norm(self.ffn_layer_norm, x, after=True) 330 | # print("ffn :", time.time() - start_time) 331 | return x 332 | 333 | def maybe_layer_norm(self, layer_norm, x, before=False, after=False): 334 | assert before ^ after 335 | if after ^ self.normalize_before: 336 | return layer_norm(x) 337 | else: 338 | return x 339 | def get_complexity(self, sequence_length): 340 | total_flops = 0 341 | if self.is_identity_layer: 342 | return total_flops 343 | total_flops += self.attn_layer_norm.get_complexity(sequence_length+1) 344 | total_flops += self.attn.get_complexity(sequence_length+1) 345 | total_flops += self.ffn_layer_norm.get_complexity(sequence_length+1) 346 | total_flops += self.fc1.get_complexity(sequence_length+1) 347 | total_flops += self.fc2.get_complexity(sequence_length+1) 348 | return total_flops 349 | 350 | def calc_dropout(dropout, sample_embed_dim, super_embed_dim): 351 | return dropout * 1.0 * sample_embed_dim / super_embed_dim 352 | 353 | 354 | 355 | 356 | 357 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | from itertools import repeat 5 | import collections.abc as container_abcs 6 | import torch.nn as nn 7 | 8 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 9 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 10 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 11 | def norm_cdf(x): 12 | # Computes standard normal cumulative distribution function 13 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 14 | 15 | if (mean < a - 2 * std) or (mean > b + 2 * std): 16 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 17 | "The distribution of values may be incorrect.", 18 | stacklevel=2) 19 | 20 | with torch.no_grad(): 21 | # Values are generated by using a truncated uniform distribution and 22 | # then using the inverse CDF for the normal distribution. 23 | # Get upper and lower cdf values 24 | l = norm_cdf((a - mean) / std) 25 | u = norm_cdf((b - mean) / std) 26 | 27 | # Uniformly fill tensor with values from [l, u], then translate to 28 | # [2l-1, 2u-1]. 29 | tensor.uniform_(2 * l - 1, 2 * u - 1) 30 | 31 | # Use inverse cdf transform for normal distribution to get truncated 32 | # standard normal 33 | tensor.erfinv_() 34 | 35 | # Transform to proper mean, std 36 | tensor.mul_(std * math.sqrt(2.)) 37 | tensor.add_(mean) 38 | 39 | # Clamp to ensure it's in the proper range 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 45 | # type: (Tensor, float, float, float, float) -> Tensor 46 | r"""Fills the input Tensor with values drawn from a truncated 47 | normal distribution. The values are effectively drawn from the 48 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 49 | with values outside :math:`[a, b]` redrawn until they are within 50 | the bounds. The method used for generating the random values works 51 | best when :math:`a \leq \text{mean} \leq b`. 52 | Args: 53 | tensor: an n-dimensional `torch.Tensor` 54 | mean: the mean of the normal distribution 55 | std: the standard deviation of the normal distribution 56 | a: the minimum cutoff value 57 | b: the maximum cutoff value 58 | Examples: 59 | >>> w = torch.empty(3, 5) 60 | >>> nn.init.trunc_normal_(w) 61 | """ 62 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 63 | 64 | def _ntuple(n): 65 | def parse(x): 66 | if isinstance(x, container_abcs.Iterable): 67 | return x 68 | return tuple(repeat(x, n)) 69 | return parse 70 | 71 | def drop_path(x, drop_prob: float = 0., training: bool = False): 72 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 73 | 74 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 75 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 76 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 77 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 78 | 'survival rate' as the argument. 79 | 80 | """ 81 | if drop_prob == 0. or not training: 82 | return x 83 | keep_prob = 1 - drop_prob 84 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 85 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 86 | random_tensor.floor_() # binarize 87 | output = x.div(keep_prob) * random_tensor 88 | return output 89 | 90 | 91 | class DropPath(nn.Module): 92 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
93 | """ 94 | def __init__(self, drop_prob=None): 95 | super(DropPath, self).__init__() 96 | self.drop_prob = drop_prob 97 | 98 | def forward(self, x): 99 | return drop_path(x, self.drop_prob, self.training) 100 | 101 | 102 | to_1tuple = _ntuple(1) 103 | to_2tuple = _ntuple(2) 104 | to_3tuple = _ntuple(3) 105 | to_4tuple = _ntuple(4) 106 | to_ntuple = _ntuple 107 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | timm==0.3.2 3 | scikit-image 4 | ptflops 5 | easydict 6 | PyYAML 7 | pillow 8 | torchvision==0.2.1 9 | opencv-python 10 | -------------------------------------------------------------------------------- /supernet_engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable, Optional 4 | from timm.utils.model import unwrap_model 5 | import torch 6 | import torch.distributed as dist 7 | from copy import deepcopy 8 | from timm.data import Mixup 9 | from timm.utils import accuracy, ModelEma 10 | from lib import utils 11 | import random 12 | import time 13 | import json 14 | from contextlib import ExitStack 15 | 16 | def sample_a_cand(candidates, grouping=None, exclude=None): 17 | exclude = exclude or [] 18 | while not grouping: 19 | idx = random.choice(range(len(candidates))) 20 | if idx not in exclude: 21 | return candidates[idx] 22 | while True: 23 | idx = random.choice(random.choice(grouping)) 24 | if idx not in exclude: 25 | return candidates[idx] 26 | 27 | def sample_candidates(candidates, eval=False, sandwich=0, sandwich_base=True, sandwich_top=True, shuffle=False, grouping=None): 28 | if eval: 29 | return candidates[0] 30 | else: 31 | if sandwich == 0: 32 | cand = sample_a_cand(candidates, grouping) 33 | if shuffle: 34 | cand = deepcopy(cand) 35 | random.shuffle(cand['mlp_ratio']) 36 | random.shuffle(cand['num_heads']) 37 | return cand 38 | else: 39 | base_cand = [] 40 | top_cand = [] 41 | exclude = [] 42 | if sandwich_base: 43 | base_cand = [candidates[0]] 44 | exclude.append(0) 45 | if sandwich_top: 46 | top_cand = [candidates[-1]] 47 | exclude.append(len(candidates)-1) 48 | inter_cands = [sample_a_cand(candidates, grouping, exclude) for _ in range(sandwich)] 49 | return base_cand + inter_cands + top_cand 50 | 51 | def sample_a_config(choices, efunc=random.choice): 52 | config = {} 53 | embed_dim = efunc(choices['embed_dim']) 54 | if isinstance(choices['depth'], dict): 55 | depth = efunc(choices['depth'][embed_dim]) 56 | else: 57 | depth = efunc(choices['depth']) 58 | dimensions = ['mlp_ratio', 'num_heads'] 59 | for dimension in dimensions: 60 | if isinstance(choices[dimension], dict): 61 | config[dimension] = [efunc(choices[dimension][embed_dim][i]) for i in range(depth)] 62 | else: 63 | config[dimension] = [efunc(choices[dimension]) for _ in range(depth)] 64 | config['embed_dim'] = [embed_dim] * depth 65 | config['layer_num'] = depth 66 | return config 67 | 68 | def sample_configs(choices, eval=False, sandwich=0, sandwich_base=True, sandwich_top=True): 69 | if eval: 70 | return sample_a_config(choices, min) 71 | else: 72 | if sandwich == 0: 73 | return sample_a_config(choices) 74 | else: 75 | base_config = [sample_a_config(choices, min)] if sandwich_base else [] 76 | top_config = [sample_a_config(choices, max)] if sandwich_top else [] 77 | inter_configs = [sample_a_config(choices) for _ in range(sandwich)] 78 | return base_config + 
inter_configs + top_config 79 | 80 | def bp_once(loss, loss_scaler=None, create_graph=False): 81 | if loss_scaler: 82 | loss_scaler._scaler.scale(loss).backward(create_graph=create_graph) 83 | else: 84 | loss.backward() 85 | 86 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 87 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 88 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 89 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 90 | amp: bool = True, teacher_model: torch.nn.Module = None, 91 | teach_loss: torch.nn.Module = None, choices=None, mode='super', retrain_config=None, 92 | print2file=False, candidates=None, sandwich=0, sandwich_base=True, sandwich_top=True, 93 | shuffle=False, grouping=None): 94 | model.train() 95 | criterion.train() 96 | 97 | # set random seed 98 | random.seed(epoch) 99 | 100 | metric_logger = utils.MetricLogger(delimiter=" ") 101 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 102 | header = 'Epoch: [{}]'.format(epoch) 103 | print_freq = 10 104 | if mode == 'retrain': 105 | config = retrain_config 106 | model_module = unwrap_model(model) 107 | print("DEBUG:retrain {}".format(config), force=print2file) 108 | model_module.set_sample_config(config=config) 109 | print("DEBUG:retrain {}".format(model_module.get_sampled_params_numel(config)), force=print2file) 110 | 111 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 112 | samples = samples.to(device, non_blocking=True) 113 | targets = targets.to(device, non_blocking=True) 114 | 115 | # sample random config 116 | if mode == 'super': 117 | sandwich_args = {'sandwich': sandwich, 118 | 'sandwich_base': sandwich_base, 119 | 'sandwich_top': sandwich_top, 120 | } 121 | if candidates is not None: 122 | config = sample_candidates(candidates, **sandwich_args, shuffle=shuffle, grouping=grouping) 123 | else: 124 | config = sample_configs(choices, **sandwich_args) 125 | if isinstance(config, dict): 126 | config = [config] 127 | model_module = unwrap_model(model) 128 | #model_module.set_sample_config(config=config) 129 | elif mode == 'retrain': 130 | config = retrain_config 131 | model_module = unwrap_model(model) 132 | model_module.set_sample_config(config=config) 133 | if mixup_fn is not None: 134 | samples, targets = mixup_fn(samples, targets) 135 | 136 | optimizer.zero_grad() 137 | # this attribute is added by timm on one optimizer (adahessian) 138 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 139 | 140 | loss_value = 0.0 141 | with ExitStack() if not amp else torch.cuda.amp.autocast(): 142 | if teacher_model: 143 | with torch.no_grad(): 144 | teach_output = teacher_model(samples) 145 | _, teacher_label = teach_output.topk(1, 1, True, True) 146 | del teach_output 147 | teacher_label.squeeze_() 148 | factor = 1.0 / len(config) 149 | for cfg in config: 150 | model_module.set_sample_config(cfg) 151 | outputs = model(samples) 152 | # gt 153 | loss = 0.5 * factor * criterion(outputs, targets) 154 | bp_once(loss, loss_scaler, is_second_order) 155 | loss_value += loss.item() 156 | # teacher 157 | loss = 0.5 * factor * teach_loss(outputs, teacher_label) 158 | bp_once(loss, loss_scaler, is_second_order) 159 | loss_value += loss.item() 160 | else: 161 | factor = 1.0 / len(config) 162 | for cfg in config: 163 | model_module.set_sample_config(cfg) 164 | loss = factor * criterion(model(samples), targets) 165 | bp_once(loss, loss_scaler, 
is_second_order) 166 | loss_value += loss.item() 167 | 168 | if not math.isfinite(loss_value): 169 | print("Loss is {}, stopping training".format(loss_value)) 170 | #sys.exit(1) 171 | continue 172 | 173 | if amp: 174 | loss_scaler(loss, optimizer, clip_grad=max_norm, 175 | parameters=model.parameters(), create_graph=is_second_order) 176 | else: 177 | optimizer.step() 178 | 179 | torch.cuda.synchronize() 180 | if model_ema is not None: 181 | model_ema.update(model) 182 | 183 | metric_logger.update(loss=loss_value) 184 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 185 | 186 | # only check at the end of epoch (avoid flooding) 187 | print("DEBUG:train {}".format(config), force=print2file) 188 | 189 | # gather the stats from all processes 190 | metric_logger.synchronize_between_processes() 191 | print("Averaged stats:", metric_logger) 192 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 193 | 194 | @torch.no_grad() 195 | def evaluate(data_loader, model, device, amp=True, choices=None, mode='super', retrain_config=None, print2file=False, candidates=None, eval_crops=1): 196 | criterion = torch.nn.CrossEntropyLoss() 197 | 198 | metric_logger = utils.MetricLogger(delimiter=" ") 199 | header = 'Test:' 200 | 201 | # switch to evaluation mode 202 | model.eval() 203 | if mode == 'super': 204 | if candidates is not None: 205 | config = sample_candidates(candidates, eval=True) 206 | else: 207 | config = sample_configs(choices, eval=True) 208 | config = [config] 209 | if utils.is_dist_avail_and_initialized(): 210 | dist.broadcast_object_list(config, src=0) 211 | config = config[0] 212 | model_module = unwrap_model(model) 213 | model_module.set_sample_config(config=config) 214 | else: 215 | config = retrain_config 216 | model_module = unwrap_model(model) 217 | model_module.set_sample_config(config=config) 218 | 219 | 220 | print("DEBUG:eval sampled model config: {}".format(config), force=print2file) 221 | parameters = model_module.get_sampled_params_numel(config) 222 | print("DEBUG:eval sampled model parameters: {}".format(parameters), force=print2file) 223 | 224 | for images, target in metric_logger.log_every(data_loader, 10, header): 225 | images = images.to(device, non_blocking=True) 226 | target = target.to(device, non_blocking=True) 227 | 228 | if eval_crops > 1: 229 | bs, ncrops, c, h, w = images.size() 230 | images = images.view(-1, c, h, w) 231 | 232 | # compute output 233 | with ExitStack() if not amp else torch.cuda.amp.autocast(): 234 | output = model(images) 235 | if eval_crops > 1: 236 | output = output.view(bs, ncrops, -1).mean(1) 237 | loss = criterion(output, target) 238 | 239 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 240 | 241 | batch_size = images.shape[0] 242 | metric_logger.update(loss=loss.item()) 243 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 244 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 245 | # gather the stats from all processes 246 | metric_logger.synchronize_between_processes() 247 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 248 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 249 | 250 | metric_logger.update(n_parameters=parameters) 251 | 252 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 253 | -------------------------------------------------------------------------------- /supernet_train.py: 
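supernet_train.py below drives the one-shot training: each iteration, train_one_epoch (from supernet_engine.py above) draws one or more sub-network configs, calls set_sample_config on the shared model, and back-propagates an averaged loss. The config dict follows the schema consumed by Vision_TransformerSuper.set_sample_config; a sketch with hypothetical values:

import random

# Hypothetical search space; the real one is read from the --cfg YAML file.
choices = {'embed_dim': [192, 216, 240], 'depth': [12, 13, 14],
           'mlp_ratio': [3.5, 4.0], 'num_heads': [3, 4]}

depth = random.choice(choices['depth'])
config = {
    'embed_dim': [random.choice(choices['embed_dim'])] * depth,  # one width, repeated per layer
    'mlp_ratio': [random.choice(choices['mlp_ratio']) for _ in range(depth)],
    'num_heads': [random.choice(choices['num_heads']) for _ in range(depth)],
    'layer_num': depth,
}

Candidates loaded from --candfile follow the same schema, with values fixed by the zero-shot search instead of drawn at random.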
-------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import sys 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import json 10 | import yaml 11 | from collections import defaultdict 12 | from pathlib import Path 13 | from pprint import pprint 14 | from timm.data import Mixup 15 | from timm.models import create_model 16 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 17 | from timm.scheduler import create_scheduler 18 | from timm.optim import create_optimizer 19 | #from timm.utils import NativeScaler 20 | from lib.cuda import NativeScaler 21 | from lib.datasets import build_dataset 22 | from supernet_engine import train_one_epoch, evaluate 23 | from lib.samplers import RASampler 24 | from lib import utils 25 | from lib.config import cfg, update_config_from_file 26 | from model.supernet_transformer import Vision_TransformerSuper 27 | 28 | 29 | def get_args_parser(): 30 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', add_help=False) 31 | parser.add_argument('--batch-size', default=64, type=int) 32 | parser.add_argument('--epochs', default=300, type=int) 33 | # config file 34 | parser.add_argument('--cfg',help='experiment configure file name',required=True,type=str) 35 | 36 | # custom parameters 37 | parser.add_argument('--platform', default='pai', type=str, choices=['itp', 'pai', 'aml'], 38 | help='Name of model to train') 39 | parser.add_argument('--teacher_model', default='', type=str, 40 | help='Name of teacher model to train') 41 | parser.add_argument('--relative_position', action='store_true') 42 | parser.add_argument('--gp', action='store_true') 43 | parser.add_argument('--change_qkv', action='store_true') 44 | parser.add_argument('--max_relative_position', type=int, default=14, help='max distance in relative position embedding') 45 | 46 | # Model parameters 47 | parser.add_argument('--model', default='', type=str, metavar='MODEL', 48 | help='Name of model to train') 49 | # AutoFormer config 50 | parser.add_argument('--mode', type=str, default='super', choices=['super', 'retrain'], help='mode of AutoFormer') 51 | parser.add_argument('--input-size', default=224, type=int) 52 | parser.add_argument('--patch_size', default=16, type=int) 53 | 54 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 55 | help='Dropout rate (default: 0.)') 56 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 57 | help='Drop path rate (default: 0.1)') 58 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 59 | help='Drop block rate (default: None)') 60 | 61 | parser.add_argument('--model-ema', action='store_true') 62 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 63 | # parser.set_defaults(model_ema=True) 64 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 65 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 66 | parser.add_argument('--rpe_type', type=str, default='bias', choices=['bias', 'direct']) 67 | parser.add_argument('--post_norm', action='store_true') 68 | parser.add_argument('--no_abs_pos', action='store_true') 69 | 70 | # Optimizer parameters 71 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 72 | help='Optimizer (default: "adamw"') 73 | parser.add_argument('--opt-eps', default=1e-8, type=float, 
metavar='EPSILON', 74 | help='Optimizer Epsilon (default: 1e-8)') 75 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 76 | help='Optimizer Betas (default: None, use opt default)') 77 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 78 | help='Clip gradient norm (default: None, no clipping)') 79 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 80 | help='SGD momentum (default: 0.9)') 81 | parser.add_argument('--weight-decay', type=float, default=0.05, 82 | help='weight decay (default: 0.05)') 83 | 84 | # Learning rate schedule parameters 85 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 86 | help='LR scheduler (default: "cosine"') 87 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 88 | help='learning rate (default: 5e-4)') 89 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 90 | help='learning rate noise on/off epoch percentages') 91 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 92 | help='learning rate noise limit percent (default: 0.67)') 93 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 94 | help='learning rate noise std-dev (default: 1.0)') 95 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 96 | help='warmup learning rate (default: 1e-6)') 97 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 98 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 99 | parser.add_argument('--lr-power', type=float, default=1.0, 100 | help='power of the polynomial lr scheduler') 101 | 102 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 103 | help='epoch interval to decay LR') 104 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 105 | help='epochs to warmup LR, if scheduler supports') 106 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 107 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 108 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 109 | help='patience epochs for Plateau LR scheduler (default: 10') 110 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 111 | help='LR decay rate (default: 0.1)') 112 | 113 | # Augmentation parameters 114 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 115 | help='Color jitter factor (default: 0.4)') 116 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 117 | help='Use AutoAugment policy. "v0" or "original". 
" + \ 118 | "(default: rand-m9-mstd0.5-inc1)'), 119 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 120 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 121 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 122 | 123 | parser.add_argument('--repeated-aug', action='store_true') 124 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 125 | 126 | 127 | parser.set_defaults(repeated_aug=True) 128 | 129 | # * Random Erase params 130 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 131 | help='Random erase prob (default: 0.25)') 132 | parser.add_argument('--remode', type=str, default='pixel', 133 | help='Random erase mode (default: "pixel")') 134 | parser.add_argument('--recount', type=int, default=1, 135 | help='Random erase count (default: 1)') 136 | parser.add_argument('--resplit', action='store_true', default=False, 137 | help='Do not random erase first (clean) augmentation split') 138 | 139 | # * Mixup params 140 | parser.add_argument('--mixup', type=float, default=0.8, 141 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 142 | parser.add_argument('--cutmix', type=float, default=1.0, 143 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 144 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 145 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 146 | parser.add_argument('--mixup-prob', type=float, default=1.0, 147 | help='Probability of performing mixup or cutmix when either/both is enabled') 148 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 149 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 150 | parser.add_argument('--mixup-mode', type=str, default='batch', 151 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 152 | 153 | # Dataset parameters 154 | parser.add_argument('--data-path', default='./data/imagenet/', type=str, 155 | help='dataset path') 156 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19', 'EVO_IMNET'], 157 | type=str, help='Image Net dataset path') 158 | parser.add_argument('--inat-category', default='name', 159 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 160 | type=str, help='semantic granularity') 161 | 162 | parser.add_argument('--output_dir', default='./logs/', 163 | help='path where to save, empty for no saving') 164 | parser.add_argument('--task', default='', help='task prefix') 165 | parser.add_argument('--device', default='cuda', 166 | help='device to use for training / testing') 167 | parser.add_argument('--seed', default=0, type=int) 168 | parser.add_argument('--resume', default='', help='resume from checkpoint') 169 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 170 | help='start epoch') 171 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 172 | parser.add_argument('--num_workers', default=10, type=int) 173 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 174 | parser.add_argument('--pin-mem', action='store_true', 175 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 176 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 177 | help='') 178 | parser.set_defaults(pin_mem=True) 179 | 180 | # distributed training parameters 181 | parser.add_argument('--world_size', default=1, type=int, 182 | help='number of distributed processes') 183 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 184 | 185 | parser.add_argument('--amp', action='store_true') 186 | parser.add_argument('--no-amp', action='store_false', dest='amp') 187 | parser.set_defaults(amp=True) 188 | 189 | parser.add_argument('--print2file', action='store_true', default=False, help='save stdout to file') 190 | parser.add_argument('--candfile', default='', type=str, help='candidates json file') 191 | parser.add_argument('--group-by-dim', action='store_true', default=False, help='group candidates by embed_dim') 192 | parser.add_argument('--group-by-depth', action='store_true', default=False, help='group candidates by depth') 193 | parser.add_argument('--sandwich', default=0, type=int, help='number of interlayers in sandwich, default 0 to turn off') 194 | parser.add_argument('--no-sandwich-base', action='store_true', default=False, help='remove the base layer of sandwich') 195 | parser.add_argument('--no-sandwich-top', action='store_true', default=False, help='remove the top layer of sandwich') 196 | parser.add_argument('--switch-ln', action='store_true', default=False, help='Enabling switchable layernorm') 197 | parser.add_argument('--scale-attn', action='store_true', default=False, help='scale attention') 198 | parser.add_argument('--scale-mlp', action='store_true', default=False, help='scale mlp') 199 | parser.add_argument('--scale-embed', action='store_true', default=False, help='scale embed dim') 200 | parser.add_argument('--shuffle', action='store_true', default=False, help='shuffle chosen candidate') 201 | parser.add_argument('--eval-crops', default=1, type=int, choices=[1, 5, 10], help='number of crops for evaluation') 202 | 203 | return parser 204 | 
205 | def main(args): 206 | 207 | utils.init_distributed_mode(args) 208 | update_config_from_file(args.cfg) 209 | 210 | print(args) 211 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 212 | 213 | if args.print2file: 214 | sys._stdout = sys.stdout 215 | sys.stdout = open(os.path.join(args.output_dir, f'out{utils.get_rank()}'), 'w', buffering=1) 216 | 217 | device = torch.device(args.device) 218 | 219 | # fix the seed for reproducibility 220 | seed = args.seed + utils.get_rank() 221 | torch.manual_seed(seed) 222 | np.random.seed(seed) 223 | # random.seed(seed) 224 | cudnn.benchmark = True 225 | 226 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args, folder_name="subImageNet" if args.data_set == "EVO_IMNET" else "train") 227 | dataset_val, _ = build_dataset(is_train=False, args=args, folder_name="val") 228 | 229 | if args.distributed: 230 | num_tasks = utils.get_world_size() 231 | global_rank = utils.get_rank() 232 | if args.repeated_aug: 233 | sampler_train = RASampler( 234 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 235 | ) 236 | else: 237 | sampler_train = torch.utils.data.DistributedSampler( 238 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 239 | ) 240 | if args.dist_eval: 241 | if len(dataset_val) % num_tasks != 0: 242 | print( 243 | 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 244 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 245 | 'equal num of samples per-process.') 246 | sampler_val = torch.utils.data.DistributedSampler( 247 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 248 | else: 249 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 250 | else: 251 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 252 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 253 | 254 | data_loader_train = torch.utils.data.DataLoader( 255 | dataset_train, sampler=sampler_train, 256 | batch_size=args.batch_size, 257 | num_workers=args.num_workers, 258 | pin_memory=args.pin_mem, 259 | drop_last=True, 260 | ) 261 | 262 | data_loader_val = torch.utils.data.DataLoader( 263 | dataset_val, batch_size=int(2 * args.batch_size), 264 | sampler=sampler_val, num_workers=args.num_workers, 265 | pin_memory=args.pin_mem, drop_last=False 266 | ) 267 | 268 | mixup_fn = None 269 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 270 | if mixup_active: 271 | mixup_fn = Mixup( 272 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 273 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 274 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 275 | 276 | print(f"Creating SuperVisionTransformer") 277 | print(cfg) 278 | 279 | choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO, 280 | 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM , 'depth': cfg.SEARCH_SPACE.DEPTH} 281 | 282 | model = Vision_TransformerSuper(img_size=args.input_size, 283 | patch_size=args.patch_size, 284 | embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH, 285 | num_heads=cfg.SUPERNET.NUM_HEADS,mlp_ratio=cfg.SUPERNET.MLP_RATIO, 286 | qkv_bias=True, drop_rate=args.drop, 287 | drop_path_rate=args.drop_path, 288 | gp=args.gp, 289 | num_classes=args.nb_classes, 290 | max_relative_position=args.max_relative_position, 291 | relative_position=args.relative_position, 292 | change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos, 293 | choices=choices, switch_ln=args.switch_ln, 294 | scale_attn=args.scale_attn, scale_mlp=args.scale_mlp, 295 | scale_embed=args.scale_embed) 296 | 297 | model.to(device) 298 | if args.teacher_model: 299 | teacher_model = create_model( 300 | args.teacher_model, 301 | pretrained=True, 302 | num_classes=args.nb_classes, 303 | ) 304 | teacher_model.to(device) 305 | teacher_loss = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 306 | else: 307 | teacher_model = None 308 | teacher_loss = None 309 | 310 | model_ema = None 311 | 312 | print(model) 313 | 314 | model_without_ddp = model 315 | if args.distributed: 316 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 317 | model_without_ddp = model.module 318 | 319 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 320 | print('number of params:', n_parameters) 321 | 322 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 323 | args.lr = linear_scaled_lr 324 | optimizer = create_optimizer(args, model_without_ddp) 325 | loss_scaler = NativeScaler() 326 | lr_scheduler, _ = create_scheduler(args, optimizer) 327 | 328 | # criterion = LabelSmoothingCrossEntropy() 329 | 330 | if args.mixup > 0.: 331 | # smoothing is handled with mixup label transform 332 | criterion = SoftTargetCrossEntropy() 333 | elif args.smoothing: 334 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 335 | else: 336 | criterion = torch.nn.CrossEntropyLoss() 337 | 338 | output_dir = Path(args.output_dir) 339 | 340 | if not output_dir.exists(): 341 | output_dir.mkdir(parents=True) 342 | # save config for later experiments 343 | if args.output_dir and utils.is_main_process(): 344 | with open(output_dir / "config.yaml", 'w') as f: 345 | f.write(args_text) 346 | if args.resume: 347 | if args.resume.startswith('https'): 348 | checkpoint = torch.hub.load_state_dict_from_url( 349 | args.resume, map_location='cpu', check_hash=True) 350 | else: 351 | checkpoint = torch.load(args.resume, map_location='cpu') 352 | model_without_ddp.load_state_dict(checkpoint['model']) 353 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 354 | optimizer.load_state_dict(checkpoint['optimizer']) 355 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 356 | args.start_epoch = checkpoint['epoch'] + 1 357 | if 
'scaler' in checkpoint: 358 | loss_scaler.load_state_dict(checkpoint['scaler']) 359 | if args.model_ema: 360 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 361 | 362 | retrain_config = None 363 | if args.mode == 'retrain' and "RETRAIN" in cfg: 364 | retrain_config = {'layer_num': cfg.RETRAIN.DEPTH, 'embed_dim': [cfg.RETRAIN.EMBED_DIM]*cfg.RETRAIN.DEPTH, 365 | 'num_heads': cfg.RETRAIN.NUM_HEADS,'mlp_ratio': cfg.RETRAIN.MLP_RATIO} 366 | candidates = None 367 | cand_index_group = None 368 | if args.candfile: 369 | cand_dict = json.load(open(args.candfile)) 370 | candidates = [cand for cand_list in cand_dict.values() for cand in cand_list] 371 | if args.group_by_dim or args.group_by_depth: 372 | cand_index_group = defaultdict(list) 373 | for idx, cand in enumerate(candidates): 374 | k = () 375 | k = k + (cand['embed_dim'][0],) if args.group_by_dim else k 376 | k = k + (cand['layer_num'],) if args.group_by_depth else k 377 | cand_index_group[k].append(idx) 378 | cand_index_group = list(cand_index_group.values()) 379 | 380 | if args.eval: 381 | if args.candfile: 382 | batch_stats = [] 383 | for retrain_config in candidates: 384 | test_stats = evaluate(data_loader_val, model, device, mode = args.mode, retrain_config=retrain_config, print2file=args.print2file, eval_crops=args.eval_crops) 385 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 386 | test_stats['config'] = retrain_config 387 | batch_stats.append(test_stats) 388 | if args.output_dir and utils.is_main_process(): 389 | json.dump(batch_stats, open(output_dir / 'results.json', 'w'), indent=2) 390 | else: 391 | test_stats = evaluate(data_loader_val, model, device, mode = args.mode, retrain_config=retrain_config, print2file=args.print2file, eval_crops=args.eval_crops) 392 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 393 | 394 | return 395 | 396 | print("Start training") 397 | start_time = time.time() 398 | max_accuracy = 0.0 399 | 400 | for epoch in range(args.start_epoch, args.epochs): 401 | if args.distributed: 402 | data_loader_train.sampler.set_epoch(epoch) 403 | 404 | train_stats = train_one_epoch( 405 | model, criterion, data_loader_train, 406 | optimizer, device, epoch, loss_scaler, 407 | args.clip_grad, model_ema, mixup_fn, 408 | amp=args.amp, teacher_model=teacher_model, 409 | teach_loss=teacher_loss, 410 | choices=choices, mode = args.mode, retrain_config=retrain_config, 411 | print2file=args.print2file, candidates=candidates, sandwich=args.sandwich, 412 | sandwich_base=not args.no_sandwich_base, sandwich_top=not args.no_sandwich_top, 413 | shuffle=args.shuffle, grouping=cand_index_group, 414 | ) 415 | 416 | lr_scheduler.step(epoch) 417 | if args.output_dir and (epoch + 1) % 10 == 0: 418 | checkpoint_paths = [output_dir / 'checkpoint-{}.pth'.format(epoch+1)] 419 | for checkpoint_path in checkpoint_paths: 420 | utils.save_on_master({ 421 | 'model': model_without_ddp.state_dict(), 422 | 'optimizer': optimizer.state_dict(), 423 | 'lr_scheduler': lr_scheduler.state_dict(), 424 | 'epoch': epoch, 425 | # 'model_ema': get_state_dict(model_ema), 426 | 'scaler': loss_scaler.state_dict(), 427 | 'args': args, 428 | }, checkpoint_path) 429 | 430 | test_stats = evaluate(data_loader_val, model, device, amp=args.amp, choices=choices, mode = args.mode, retrain_config=retrain_config, print2file=args.print2file, candidates=candidates, eval_crops=args.eval_crops) 431 | print(f"Accuracy of the network on the 
{len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 432 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 433 | print(f'Max accuracy: {max_accuracy:.2f}%') 434 | 435 | log_stats = {'datetime': datetime.datetime.now().strftime("%m/%d %H:%M"), 436 | **{f'train_{k}': v for k, v in train_stats.items()}, 437 | **{f'test_{k}': v for k, v in test_stats.items()}, 438 | 'epoch': epoch} 439 | 440 | if args.output_dir and utils.is_main_process(): 441 | with (output_dir / "log.txt").open("a") as f: 442 | f.write(json.dumps(log_stats) + "\n") 443 | 444 | total_time = time.time() - start_time 445 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 446 | print('Training time {}'.format(total_time_str)) 447 | 448 | 449 | if __name__ == '__main__': 450 | now = datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S") 451 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', parents=[get_args_parser()]) 452 | args = parser.parse_args() 453 | if args.task: 454 | suffix = args.task 455 | elif args.resume: 456 | resume_folder = os.path.basename(os.path.dirname(os.path.normpath(args.resume))) 457 | suffix = resume_folder.partition('@')[-1] 458 | else: 459 | suffix = '' 460 | sep = '@' if suffix else '' 461 | args.output_dir = os.path.join(args.output_dir, 'test' if args.eval else 'train', now+sep+suffix) 462 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 463 | main(args) 464 | -------------------------------------------------------------------------------- /two_step_search.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | import yaml 9 | import random 10 | from pathlib import Path 11 | from timm.data import Mixup 12 | from timm.models import create_model 13 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 14 | from timm.utils import NativeScaler 15 | from lib.datasets import build_dataset 16 | from lib.samplers import RASampler 17 | from lib import utils 18 | from lib.config import cfg, update_config_from_file 19 | from lib.score_maker import ScoreMaker 20 | from model.supernet_transformer import Vision_TransformerSuper 21 | from evolution_pre_train import Searcher 22 | 23 | 24 | def get_args_parser(): 25 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', add_help=False) 26 | parser.add_argument('--batch-size', default=64, type=int) 27 | # config file 28 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 29 | 30 | # custom parameters 31 | parser.add_argument('--platform', default='pai', type=str, choices=['itp', 'pai', 'aml'], 32 | help='platform to run on') 33 | parser.add_argument('--teacher_model', default='', type=str, 34 | help='Name of the teacher model') 35 | parser.add_argument('--relative_position', action='store_true') 36 | parser.add_argument('--gp', action='store_true') 37 | parser.add_argument('--change_qkv', action='store_true') 38 | parser.add_argument('--max_relative_position', type=int, default=14, help='max distance in relative position embedding') 39 | 40 | # Model parameters 41 | parser.add_argument('--model', default='', type=str, metavar='MODEL', 42 | help='Name of model to train') 43 | 44 | parser.add_argument('--input-size', default=224, type=int) 45 | parser.add_argument('--patch_size', default=16, type=int) 46 | 47 | 
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 48 | help='Dropout rate (default: 0.)') 49 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 50 | help='Drop path rate (default: 0.1)') 51 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 52 | help='Drop block rate (default: None)') 53 | 54 | parser.add_argument('--model-ema', action='store_true') 55 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 56 | # parser.set_defaults(model_ema=True) 57 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='model EMA decay factor (default: 0.99996)') 58 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='keep the model EMA on CPU') 59 | parser.add_argument('--rpe_type', type=str, default='bias', choices=['bias', 'direct']) 60 | parser.add_argument('--post_norm', action='store_true') 61 | parser.add_argument('--no_abs_pos', action='store_true') 62 | 63 | # Optimizer parameters 64 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 65 | help='Optimizer (default: "adamw")') 66 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 67 | help='Optimizer Epsilon (default: 1e-8)') 68 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 69 | help='Optimizer Betas (default: None, use opt default)') 70 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 71 | help='Clip gradient norm (default: None, no clipping)') 72 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 73 | help='SGD momentum (default: 0.9)') 74 | parser.add_argument('--weight-decay', type=float, default=0.05, 75 | help='weight decay (default: 0.05)') 76 | 77 | # Learning rate schedule parameters 78 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 79 | help='LR scheduler (default: "cosine")') 80 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 81 | help='learning rate (default: 5e-4)') 82 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 83 | help='learning rate noise on/off epoch percentages') 84 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 85 | help='learning rate noise limit percent (default: 0.67)') 86 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 87 | help='learning rate noise std-dev (default: 1.0)') 88 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 89 | help='warmup learning rate (default: 1e-6)') 90 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 91 | help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)') 92 | parser.add_argument('--lr-power', type=float, default=1.0, 93 | help='power of the polynomial lr scheduler') 94 | 95 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 96 | help='epoch interval to decay LR') 97 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 98 | help='epochs to warmup LR, if scheduler supports') 99 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 100 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 101 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 102 | help='patience epochs for Plateau LR scheduler (default: 10)') 103 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 104 | 
help='LR decay rate (default: 0.1)') 105 | 106 | # Augmentation parameters 107 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 108 | help='Color jitter factor (default: 0.4)') 109 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 110 | help='Use AutoAugment policy, e.g. "v0" or "original" ' 111 | '(default: rand-m9-mstd0.5-inc1)') 112 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 113 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 114 | help='Training interpolation: random, bilinear, bicubic (default: "bicubic")') 115 | 116 | parser.add_argument('--repeated-aug', action='store_true') 117 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 118 | 119 | # * Random Erase params 120 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 121 | help='Random erase prob (default: 0.25)') 122 | parser.add_argument('--remode', type=str, default='pixel', 123 | help='Random erase mode (default: "pixel")') 124 | parser.add_argument('--recount', type=int, default=1, 125 | help='Random erase count (default: 1)') 126 | parser.add_argument('--resplit', action='store_true', default=False, 127 | help='Do not random erase first (clean) augmentation split') 128 | 129 | # * Mixup params 130 | parser.add_argument('--mixup', type=float, default=0.8, 131 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 132 | parser.add_argument('--cutmix', type=float, default=1.0, 133 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 134 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 135 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 136 | parser.add_argument('--mixup-prob', type=float, default=1.0, 137 | help='Probability of performing mixup or cutmix when either/both is enabled') 138 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 139 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 140 | parser.add_argument('--mixup-mode', type=str, default='batch', 141 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 142 | 143 | # Dataset parameters 144 | parser.add_argument('--data-path', default='./data/imagenet/', type=str, 145 | help='dataset path') 146 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 147 | type=str, help='Image Net dataset path') 148 | parser.add_argument('--inat-category', default='name', 149 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 150 | type=str, help='semantic granularity') 151 | parser.add_argument('--output_dir', default='./', 152 | help='path where to save, empty for no saving') 153 | parser.add_argument('--device', default='cuda', 154 | help='device to use for training / testing') 155 | parser.add_argument('--seed', default=0, type=int) 156 | parser.add_argument('--resume', default='', help='resume from checkpoint') 157 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 158 | help='start epoch') 159 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 160 | parser.add_argument('--num_workers', default=10, type=int) 161 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 162 | parser.add_argument('--pin-mem', action='store_true', 163 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 164 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') 165 | parser.set_defaults(pin_mem=True) 166 | 167 | # distributed training parameters 168 | parser.add_argument('--world_size', default=1, type=int, 169 | help='number of distributed processes') 170 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 171 | parser.add_argument('--amp', action='store_true') 172 | parser.add_argument('--no-amp', action='store_false', dest='amp') 173 | parser.add_argument('--score-method', default='left_super_taylor9', type=str, 174 | help='Score method in step two') 175 | parser.add_argument('--block-score-method-for-head', default='balance_taylor5_max_dim', type=str, 176 | help='Score method for head in step one') 177 | parser.add_argument('--block-score-method-for-mlp', default='deeper_is_better', type=str, 178 | help='Score method for mlp in step one') 179 | parser.add_argument('--candidate-path', default='the path of interval candidates',type=str) 180 | parser.add_argument('--super-model-size', default='T', type=str) 181 | parser.add_argument('--interval-cands-output', default='./out/interval_candidates.pt', type=str) 182 | parser.add_argument('--min_param_limits', default=4, type=float) 183 | parser.add_argument('--param_limits', default=12, type=float) 184 | parser.add_argument('--param-interval', default=2, type=float) 185 | parser.add_argument('--cand-per-interval', default=1, type=int) 186 | parser.add_argument('--population-num', default=50, type=int) 187 | parser.add_argument('--max-epochs', default=20, type=int) 188 | parser.add_argument('--select-num', type=int, default=20) 189 | parser.add_argument('--m_prob', type=float, default=0.2) 190 | parser.add_argument('--s_prob', type=float, default=0.4) 191 | parser.add_argument('--crossover-num', type=int, default=25) 192 | parser.add_argument('--mutation-num', type=int, default=25) 193 | parser.add_argument('--data-free', action='store_true', help='False if use the data to get gradient.') 194 | parser.add_argument('--reallocate', action='store_true', help='if reallocate when random and evolution 
search.') 195 | parser.add_argument('--avg-dim-sample', action='store_true', help='True if sample the dimension in uniform distribution.') 196 | parser.add_argument('--search-mode', default='iteration', choices=['iteration', 'random', 'evolution'], type=str, help='The mode of search the candidates.') 197 | parser.set_defaults(amp=True) 198 | 199 | return parser 200 | 201 | 202 | def main(args): 203 | 204 | utils.init_distributed_mode(args) 205 | update_config_from_file(args.cfg) 206 | 207 | print(args) 208 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 209 | 210 | device = torch.device(args.device) 211 | 212 | # fix the seed for reproducibility 213 | seed = args.seed + utils.get_rank() 214 | torch.manual_seed(seed) 215 | np.random.seed(seed) 216 | # random.seed(seed) 217 | cudnn.benchmark = True 218 | 219 | dataset_sub_train, args.nb_classes = build_dataset(is_train=True, args=args, folder_name="subImageNet") 220 | 221 | if args.distributed: 222 | num_tasks = utils.get_world_size() 223 | global_rank = utils.get_rank() 224 | if args.repeated_aug: 225 | sampler_sub_train = RASampler( 226 | dataset_sub_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 227 | ) 228 | else: 229 | sampler_sub_train = torch.utils.data.DistributedSampler( 230 | dataset_sub_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 231 | ) 232 | else: 233 | sampler_sub_train = torch.utils.data.RandomSampler(dataset_sub_train) 234 | 235 | data_loader_sub_train = torch.utils.data.DataLoader( 236 | dataset_sub_train, batch_size=args.batch_size, 237 | sampler=sampler_sub_train, num_workers=args.num_workers, 238 | pin_memory=args.pin_mem, drop_last=False 239 | ) 240 | 241 | mixup_fn = None 242 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 243 | if mixup_active: 244 | mixup_fn = Mixup( 245 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 246 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 247 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 248 | 249 | print(f"Creating SuperVisionTransformer") 250 | print(cfg) 251 | model = Vision_TransformerSuper(img_size=args.input_size, 252 | patch_size=args.patch_size, 253 | embed_dim=cfg.SUPERNET.EMBED_DIM, depth=cfg.SUPERNET.DEPTH, 254 | num_heads=cfg.SUPERNET.NUM_HEADS, mlp_ratio=cfg.SUPERNET.MLP_RATIO, 255 | qkv_bias=True, drop_rate=args.drop, 256 | drop_path_rate=args.drop_path, 257 | gp=args.gp, 258 | num_classes=args.nb_classes, 259 | max_relative_position=args.max_relative_position, 260 | relative_position=args.relative_position, 261 | change_qkv=args.change_qkv, abs_pos=not args.no_abs_pos) 262 | 263 | choices = {'num_heads': cfg.SEARCH_SPACE.NUM_HEADS, 'mlp_ratio': cfg.SEARCH_SPACE.MLP_RATIO, 264 | 'embed_dim': cfg.SEARCH_SPACE.EMBED_DIM, 'depth': cfg.SEARCH_SPACE.DEPTH} 265 | 266 | model.to(device) 267 | 268 | model_without_ddp = model 269 | if args.distributed: 270 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 271 | model_without_ddp = model.module 272 | 273 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 274 | print('number of params:', n_parameters) 275 | 276 | if args.mixup > 0.: 277 | # smoothing is handled with mixup label transform 278 | criterion = SoftTargetCrossEntropy() 279 | elif args.smoothing: 280 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 281 | else: 282 | criterion = torch.nn.CrossEntropyLoss() 283 | 284 | output_dir = Path(args.output_dir) 285 | 286 | if not output_dir.exists(): 287 | output_dir.mkdir(parents=True) 288 | 289 | if args.resume: 290 | print('resume') 291 | if args.resume.startswith('https'): 292 | checkpoint = torch.hub.load_state_dict_from_url( 293 | args.resume, map_location='cpu', check_hash=True) 294 | else: 295 | checkpoint = torch.load(args.resume, map_location='cpu') 296 | model_without_ddp.load_state_dict(checkpoint['model']) 297 | if 'epoch' in checkpoint: 298 | args.start_epoch = checkpoint['epoch'] + 1 299 | 300 | print("Start search candidate") 301 | score_maker = ScoreMaker() 302 | score_maker.get_gradient(model, criterion, data_loader_sub_train, args, choices, device, mixup_fn=mixup_fn) 303 | evolution_searcher = Searcher(args, device, model, model_without_ddp, choices, output_dir, score_maker) 304 | print(evolution_searcher.get_params_range()) 305 | interval_candidates = evolution_searcher.search(args.interval_cands_output) 306 | score_maker.drop_gradient() 307 | 308 | 309 | if __name__ == '__main__': 310 | parser = argparse.ArgumentParser('AutoFormer training and evaluation script', parents=[get_args_parser()]) 311 | args = parser.parse_args() 312 | main(args) 313 | --------------------------------------------------------------------------------
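Editor's note: the sketch below is illustrative and not part of the repository. It shows how the size-interval arguments of two_step_search.py (--min_param_limits, --param_limits, --param-interval and --cand-per-interval) partition the parameter budget reported by Searcher.get_params_range(); the helper name param_intervals is hypothetical, and the actual bucketing inside evolution_pre_train.Searcher may differ in detail.

def param_intervals(min_limit=4.0, max_limit=12.0, interval=2.0):
    """Split [min_limit, max_limit] (in millions of parameters) into
    consecutive buckets of width `interval`. The zero-shot search keeps the
    --cand-per-interval best-scoring architectures per bucket and writes them
    to --interval-cands-output as a JSON mapping from interval to candidate
    list, which supernet_train.py later consumes via --candfile."""
    buckets, lo = [], min_limit
    while lo < max_limit:
        buckets.append((lo, min(lo + interval, max_limit)))
        lo += interval
    return buckets

# With the parser defaults above, param_intervals() returns
# [(4.0, 6.0), (6.0, 8.0), (8.0, 10.0), (10.0, 12.0)]: four size buckets,
# each contributing one candidate to the one-shot supernet training stage.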