├── LICENSE ├── README.md ├── analyze_flops.py ├── analyze_flops_plt.py ├── config ├── mcea_searching.yaml └── training_super_supernet.yaml ├── core ├── __init__.py ├── agent │ ├── __init__.py │ └── nas_agent.py ├── dataset │ ├── augmentations │ │ ├── __init__.py │ │ └── augmentation.py │ ├── build_dataloader.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base_mc_dataset.py │ │ └── single_task_mc_dataset.py │ └── samplers │ │ ├── __init__.py │ │ ├── base_sampler.py │ │ └── dist_iter_batch_sampler.py ├── model │ ├── __init__.py │ ├── basemodel.py │ └── net.py ├── sampler │ ├── __init__.py │ ├── base_sampler.py │ ├── build_sampler.py │ └── evolution │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── evolution_sampler.cpython-36.pyc │ │ ├── evolution_sampler.cpython-37.pyc │ │ ├── nsganet.cpython-36.pyc │ │ ├── nsganet.cpython-37.pyc │ │ ├── sense_evolution_sampler.cpython-36.pyc │ │ └── sense_evolution_sampler.cpython-37.pyc │ │ ├── evolution_sampler.py │ │ └── nsganet.py ├── search_space │ ├── __init__.py │ ├── model_initializer.py │ └── ops.py ├── searcher │ ├── __init__.py │ ├── base_searcher.py │ ├── build_searcher.py │ └── uniform_searcher.py └── utils │ ├── __init__.py │ ├── arch_util.py │ ├── flops.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── measure.py │ ├── misc.py │ └── optimizer.py ├── gen_subnet.py ├── get_best_subnet.py ├── sample ├── run.sh ├── search.sh └── train.sh └── tools ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── dist_init.cpython-36.pyc └── dist_init.cpython-37.pyc ├── agent_run.py ├── dist_init.py ├── eval ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_tester.cpython-36.pyc │ ├── base_tester.cpython-37.pyc │ ├── build_tester.cpython-36.pyc │ └── build_tester.cpython-37.pyc ├── base_tester.py ├── build_tester.py └── imagenet │ ├── __init__.py │ ├── __pycache__ │ ├── 
__init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── tester.cpython-36.pyc │ └── tester.cpython-37.pyc │ └── tester.py └── trainer ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── base_trainer.cpython-36.pyc ├── base_trainer.cpython-37.pyc ├── build_trainer.cpython-36.pyc └── build_trainer.cpython-37.pyc ├── base_trainer.py ├── build_trainer.py └── imagenet ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── trainer.cpython-36.pyc └── trainer.cpython-37.pyc └── trainer.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ScaleNet 2 | ScaleNet: Searching for the Model to Scale (ECCV 2022) [ECVA](https://www.ecva.net/papers.php "ECVA") or [ArXiv](https://arxiv.org/abs/2207.07267 "ArXiv") 3 | 4 | ## Requirements 5 | - python >= 3.6 6 | - 1.0.0 <= PyTorch <= 1.3.0 7 | - torchvision >= 0.3.0 8 | - pymoo == 0.3.0 9 | - pymop == latest version 10 | 11 | ## Super-supernet Training 12 | - Download datasets 13 | - Run: `bash train.sh` 14 | 15 | ## Searching 16 | - Run: `bash search.sh` 17 | 18 | ## Retraining 19 | - Please directly apply the package [TIMM](https://github.com/rwightman/pytorch-image-models "TIMM"). 20 | 21 | ## Warning 22 | - dataloader and dataset may need to be modified for adapting your environment. 
23 | 24 | ## Citation 25 | If you find this paper useful in your research, please consider citing: 26 | ``` 27 | @InProceedings{xie2022scalenet, 28 | author={Jiyang Xie and Xiu Su and Shan You and Zhanyu Ma and Fei Wang and Chen Qian}, 29 | booktitle={European Conference on Computer Vision (ECCV)}, 30 | title={{ScaleNet}: {S}earching for the Model to Scale}, 31 | year={2022}, 32 | volume={30}, 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /analyze_flops.py: -------------------------------------------------------------------------------- 1 | import math 2 | res = {} 3 | f = open('statistic_flops.txt', 'r') 4 | for line in f: 5 | if len(line) == 0: 6 | continue 7 | flops = int(line.split(',')[-1]) 8 | f_quanti = math.floor(flops / 1e8) 9 | if res.get(f_quanti) is None: 10 | res[f_quanti] = 0 11 | res[f_quanti] += 1 12 | 13 | print(res) 14 | 15 | res_l = [] 16 | for k in res: 17 | res_l.append((k, res[k])) 18 | 19 | res_l.sort(key=lambda x: x[0]) 20 | 21 | total = sum([x[1] for x in res_l]) 22 | print('total num: {}'.format(sum([x[1] for x in res_l]))) 23 | print(', '.join(['{}00M: {}'.format(x[0], x[1]) for x in res_l])) 24 | print(', '.join(['{}00M: {:.3f}'.format(x[0], x[1] / total) for x in res_l])) 25 | 26 | -------------------------------------------------------------------------------- /analyze_flops_plt.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | f = open('statistic_flops.txt', 'r') 4 | data = [] 5 | for line in f: 6 | if len(line) == 0: 7 | continue 8 | flops = int(line.split(',')[-1]) 9 | data.append(flops / 1e6) 10 | 11 | bins = [x*50 for x in range(1, 16)] 12 | plt.hist(data, bins=bins, normed=0, facecolor="blue", edgecolor="black", alpha=0.7) 13 | plt.xlabel('FLOPs') 14 | plt.ylabel('number') 15 | plt.title('FLOPs statistics') 16 | plt.savefig('flops_statistics.png') 17 | 
-------------------------------------------------------------------------------- /config/mcea_searching.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | # Search space of base model 3 | # [n, stride, c_in, c_out, [expand_ratio], channel_search, [op1, op2, op3]] 4 | # NOTICE: by default, in a stage, only first layer will use stride!=1 and change channel to oup 5 | backbone: 6 | conv_stem: [1, 2, 3, 32, [], False, ['conv3x3']] 7 | stage1: [1, 1, 32, 16, [[1], [1], [1], [1]], False, 8 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 9 | stage2: [4, 2, 16, 32, [[1], [6], [6], [6]], False, 10 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 11 | stage3: [4, 2, 32, 40, [[1], [6], [6], [6]], False, 12 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 13 | stage4: [4, 2, 40, 80, [[1], [6], [6], [6]], False, 14 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 15 | stage5: [4, 1, 80, 96, [[1], [6], [6], [6]], False, 16 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 17 | stage6: [4, 2, 96, 192, [[1], [6], [6], [6]], False, 18 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 19 | stage7: [1, 1, 192, 320, [[1], [6], [6], [6]], False, 20 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 21 | conv_out: [1, 1, 320, 1280, [], False, ['conv2d']] 22 | final_pooling: True 23 | head: 24 | linear1: 25 | dim_in: 1280 26 | dim_out: 1000 27 | loss_type: 's-softmax' 28 | 29 | search: 30 | flag: False 31 | searcher: 32 | type: ['uniform'] 33 | start_iter: [0] 34 | depth_multiplier: [[1.00], # scaling stage 0 35 | [1.04, 1.08, 1.12, 1.16], # scaling stage 1 36 | [1.20, 1.24, 1.28, 1.32, 1.36], # scaling stage 2 37 | [1.40, 1.44, 1.48, 1.52, 1.56, 1.60, 1.64]] # scaling stage 3 38 | channel_multiplier: [[1.00], # scaling stage 0 39 | [1.04, 1.08, 1.12, 1.16], # scaling stage 1 40 | [1.20, 1.24, 1.28, 1.32, 1.36], # scaling stage 2 41 | [1.40, 1.44, 1.48, 1.52, 1.56, 1.60, 1.64]] # scaling stage 3 42 | resolution_multiplier: [[224], 43 | 
[224, 240, 256], 44 | [272, 288, 304], 45 | [320, 336, 354]] 46 | max_scaling_stage: 3 # int 47 | n_laterally_couplng: 2 # n-laterally couplng for channels, n=2^j, 48 | # 0 or 1: AutoSlim, 2: BCNet 49 | asyn: True # Asynchronous n-laterally couplng for channels 50 | 51 | strategy: 52 | max_iter: 750000 53 | optimizer: 54 | type: 'SGD' 55 | lr: 0.12 56 | weight_decay: 0.00004 57 | momentum: 0.9 58 | nesterov: True 59 | lr_scheduler: 60 | lr_steps: [50080, 100160, 125200] 61 | lr_mults: [0.1, 0.1, 0.1] 62 | warmup_steps: 375 63 | warmup_strategy: 'gradual' 64 | warmup_lr: 0.2 65 | decay_stg: 'cosine' 66 | # final lr in cosine strategy 67 | alpha: 0. 68 | # how many iterations it takes to decay lr to 'alpha' 69 | decay_step: 750000 70 | 71 | task_type: 'imagenet' 72 | snapshot_freq: 1000 73 | print_freq: 100 74 | resume: True 75 | save_path: '../generalNAS_exp/scaling' 76 | load_name: 'latest.pth.tar' 77 | 78 | data: 79 | workers: 6 # dataloader worker num 80 | task_type: 'imagenet' 81 | data_type: 'ssst' 82 | scatter_mode: False 83 | final_height: 224 84 | final_width: 224 85 | final_channel: 3 86 | augmentation: 87 | rand_resize: 88 | output_size: 224 89 | scale: [0.08, 1.0] 90 | ratio: [0.75, 1.33] 91 | # resize 92 | resize: 93 | output_size: [224, 224] 94 | # normalize 95 | normalize: 96 | normalize_type: 'mean_std' 97 | mean: [123.675, 116.28, 103.53] 98 | std: [58.395, 57.120, 57.375] 99 | imagenet: 100 | type: 'classification' 101 | task: 'imagenet' 102 | json_path: /mnt/lustre/xiejiyang/nas/imagenet.json 103 | prefix: '/mnt/lustreold/share/images/train' 104 | batch_size: 8 # for single gpu 105 | 106 | sample: 107 | flag: True 108 | sampler: 109 | type: 'evolution' 110 | kwargs: 111 | flops_constraint: 306e6 112 | pop_size: 64 113 | n_gens: 40 114 | sample_num: 50 115 | 116 | pop_size_scaling: 64 117 | n_gens_scaling: 40 118 | sample_num_scaling: 50 119 | n_basenet: 20 120 | admm_iter: 8 121 | max_scaling_stage: 3 122 | strategy: 123 | task_type: 
'imagenet-test' 124 | snapshot_freq: 1000 125 | print_freq: 100 126 | save_path: '../generalNAS_exp/scaling' 127 | load_name: 'latest.pth.tar' 128 | resume_name: 'arch_topk.yaml' 129 | 130 | data: 131 | workers: 128 # dataloader worker num 132 | task_type: 'imagenet-test' 133 | data_type: 'ssst' 134 | scatter_mode: False 135 | final_height: 224 136 | final_width: 224 137 | final_channel: 3 138 | augmentation: 139 | # resize 140 | resize: 141 | output_size: 256 142 | center_crop: 143 | output_size: 224 144 | # normalize 145 | normalize: 146 | normalize_type: 'mean_std' 147 | mean: [123.675, 116.28, 103.53] 148 | std: [58.395, 57.120, 57.375] 149 | imagenet: 150 | type: 'classification' 151 | task: 'imagenet' 152 | json_path: '/mnt/lustre/xiejiyang/nas/imagenet_minival.json' #'/mnt/lustre/xiejiyang/nas/imagenet_val.json' 153 | prefix: '/mnt/lustreold/share/images/val' 154 | batch_size: 64 # for single gpu 155 | -------------------------------------------------------------------------------- /config/training_super_supernet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | # Search space of base model 3 | # [n, stride, c_in, c_out, [expand_ratio], channel_search, [op1, op2, op3]] 4 | # NOTICE: by default, in a stage, only first layer will use stride!=1 and change channel to oup 5 | backbone: 6 | conv_stem: [1, 2, 3, 32, [], False, ['conv3x3']] 7 | stage1: [1, 1, 32, 16, [[1], [1], [1], [1]], False, 8 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 9 | stage2: [4, 2, 16, 32, [[1], [6], [6], [6]], False, 10 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 11 | stage3: [4, 2, 32, 40, [[1], [6], [6], [6]], False, 12 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 13 | stage4: [4, 2, 40, 80, [[1], [6], [6], [6]], False, 14 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 15 | stage5: [4, 1, 80, 96, [[1], [6], [6], [6]], False, 16 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 17 | stage6: [4, 2, 96, 192, [[1], [6], 
[6], [6]], False, 18 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 19 | stage7: [1, 1, 192, 320, [[1], [6], [6], [6]], False, 20 | ['id', 'ir_3x3_se', 'ir_5x5_se', 'ir_7x7_se']] 21 | conv_out: [1, 1, 320, 1280, [], False, ['conv2d']] 22 | final_pooling: True 23 | head: 24 | linear1: 25 | dim_in: 1280 26 | dim_out: 1000 27 | loss_type: 's-softmax' 28 | 29 | search: 30 | flag: True 31 | searcher: 32 | type: ['uniform'] 33 | start_iter: [0] 34 | depth_multiplier: [[1.00], # scaling stage 0 35 | [1.04, 1.08, 1.12, 1.16], # scaling stage 1 36 | [1.20, 1.24, 1.28, 1.32, 1.36], # scaling stage 2 37 | [1.40, 1.44, 1.48, 1.52, 1.56, 1.60, 1.64]] # scaling stage 3 38 | channel_multiplier: [[1.00], # scaling stage 0 39 | [1.04, 1.08, 1.12, 1.16], # scaling stage 1 40 | [1.20, 1.24, 1.28, 1.32, 1.36], # scaling stage 2 41 | [1.40, 1.44, 1.48, 1.52, 1.56, 1.60, 1.64]] # scaling stage 3 42 | resolution_multiplier: [[224], 43 | [224, 240, 256], 44 | [272, 288, 304], 45 | [320, 336, 354]] 46 | max_scaling_stage: 3 # int 47 | n_laterally_couplng: 2 # n-laterally couplng for channels, n=2^j, 48 | # 0 or 1: AutoSlim, 2: BCNet 49 | asyn: True # Asynchronous n-laterally couplng for channels 50 | 51 | strategy: 52 | max_iter: 750000 53 | optimizer: 54 | type: 'SGD' 55 | lr: 0.12 56 | weight_decay: 0.00004 57 | momentum: 0.9 58 | nesterov: True 59 | lr_scheduler: 60 | lr_steps: [50080, 100160, 125200] 61 | lr_mults: [0.1, 0.1, 0.1] 62 | warmup_steps: 375 63 | warmup_strategy: 'gradual' 64 | warmup_lr: 0.2 65 | decay_stg: 'cosine' 66 | # final lr in cosine strategy 67 | alpha: 0. 
68 | # how many iterations it takes to decay lr to 'alpha' 69 | decay_step: 750000 70 | 71 | task_type: 'imagenet' 72 | snapshot_freq: 1000 73 | print_freq: 100 74 | resume: True 75 | save_path: '../generalNAS_exp/scaling' 76 | load_name: 'latest.pth.tar' 77 | 78 | data: 79 | workers: 6 # dataloader worker num 80 | task_type: 'imagenet' 81 | data_type: 'ssst' 82 | scatter_mode: False 83 | final_height: 224 84 | final_width: 224 85 | final_channel: 3 86 | augmentation: 87 | rand_resize: 88 | output_size: 224 89 | scale: [0.08, 1.0] 90 | ratio: [0.75, 1.33] 91 | # resize 92 | resize: 93 | output_size: [224, 224] 94 | # normalize 95 | normalize: 96 | normalize_type: 'mean_std' 97 | mean: [123.675, 116.28, 103.53] 98 | std: [58.395, 57.120, 57.375] 99 | imagenet: 100 | type: 'classification' 101 | task: 'imagenet' 102 | json_path: /mnt/lustre/xiejiyang/nas/imagenet.json 103 | prefix: '/mnt/lustreold/share/images/train' 104 | batch_size: 8 # for single gpu 105 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/__init__.py -------------------------------------------------------------------------------- /core/agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/agent/__init__.py -------------------------------------------------------------------------------- /core/agent/nas_agent.py: -------------------------------------------------------------------------------- 1 | from core.model.net import Net 2 | from core.searcher import build_searcher 3 | from core.sampler import build_sampler 4 | from core.utils.lr_scheduler import IterLRScheduler 5 | from core.utils.optimizer import build_optimizer 6 | from 
core.dataset.build_dataloader import build_dataloader 7 | from tools.eval.build_tester import build_tester 8 | from tools.trainer import build_trainer 9 | import torch.nn as nn 10 | 11 | try: 12 | from apex import amp 13 | from apex.parallel import DistributedDataParallel as DDP 14 | has_apex = True 15 | except ImportError: 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | has_apex = False 18 | import time 19 | from os.path import join, exists 20 | import os 21 | from core.utils.flops import count_flops 22 | import logging 23 | import torch 24 | from tools.dist_init import dist_init 25 | 26 | class FormatterNoInfo(logging.Formatter): 27 | def __init__(self, fmt='%(levelname)s: %(message)s'): 28 | logging.Formatter.__init__(self, fmt) 29 | 30 | def format(self, record): 31 | if record.levelno == logging.INFO: 32 | return str(record.getMessage()) 33 | return logging.Formatter.format(self, record) 34 | 35 | class NASAgent: 36 | def __init__(self, config): 37 | self.cfg_net = config.pop('model') 38 | self.cfg_search = config.pop('search') 39 | self.cfg_sample = config.pop('sample') 40 | 41 | def run(self): 42 | console_handler = logging.StreamHandler() 43 | console_handler.setFormatter(FormatterNoInfo()) 44 | logging.root.addHandler(console_handler) 45 | logging.root.setLevel(logging.INFO) 46 | 47 | #search config 48 | cfg_search_searcher = self.cfg_search.pop('searcher') 49 | cfg_search_stg = self.cfg_search.pop('strategy') 50 | cfg_search_data = self.cfg_search.pop('data') 51 | #sample config 52 | cfg_sample_sampler = self.cfg_sample.pop('sampler') 53 | cfg_sample_stg = self.cfg_sample.pop('strategy') 54 | cfg_sample_data = self.cfg_sample.pop('data') 55 | 56 | self.rank, self.local_rank, self.world_size = dist_init() 57 | print('==rank{}==local rank{}==world size{}'.format(self.rank, self.local_rank, self.world_size)) 58 | 59 | torch.manual_seed(42 + self.rank) 60 | 61 | # build model 62 | self._build_model(self.cfg_net, cfg_search_searcher) 63 
| # search 64 | if self.cfg_search['flag']: 65 | if self.rank == 0: 66 | if not exists(join(cfg_search_stg['save_path'], 'checkpoint')): 67 | os.makedirs(join(cfg_search_stg['save_path'], 'checkpoint')) 68 | if not exists(join(cfg_search_stg['save_path'], 'events')): 69 | os.makedirs(join(cfg_search_stg['save_path'], 'events')) 70 | if not exists(join(cfg_search_stg['save_path'], 'log')): 71 | os.makedirs(join(cfg_search_stg['save_path'], 'log')) 72 | 73 | self._build_searcher(cfg_search_searcher, cfg_search_data, cfg_search_stg) 74 | self.search() 75 | 76 | # sample 77 | if self.cfg_sample['flag']: 78 | if not exists(cfg_sample_stg['save_path']): 79 | os.makedirs(cfg_sample_stg['save_path']) 80 | self._build_sampler(cfg_sample_sampler, cfg_sample_data, cfg_sample_stg, self.cfg_net, 81 | cfg_search_searcher) 82 | self.sample() 83 | self.subnet_candidates = self.sampler.generate_subnet() 84 | 85 | def _build_model(self, cfg_net, cfg_sample_sampler): 86 | self.model = Net(cfg_net, cfg_sample_sampler).cuda() 87 | 88 | def _build_searcher(self, cfg_searcher, cfg_data_search, cfg_stg_search): 89 | self.search_dataloader = build_dataloader(cfg_data_search, cfg_searcher) 90 | 91 | opt = build_optimizer(self.model, cfg_stg_search['optimizer']) 92 | if has_apex: 93 | self.model, opt = amp.initialize(self.model, opt, opt_level='O1', min_loss_scale=2.**10) 94 | if self.local_rank == 0: 95 | logging.info('NVIDIA APEX installed. AMP on.') 96 | self.model = DDP(self.model, delay_allreduce=True) 97 | else: 98 | if self.local_rank == 0: 99 | logging.info("Using torch DistributedDataParallel. 
Install NVIDIA Apex for Apex DDP.") 100 | self.model = DDP(self.model, device_ids=[self.local_rank], find_unused_parameters=True) 101 | 102 | lr_scheduler = IterLRScheduler(opt, **cfg_stg_search['lr_scheduler']) 103 | 104 | for searcher_type, start_iter in zip(cfg_searcher['type'], cfg_searcher['start_iter']): 105 | searcher = build_searcher(searcher_type, cfg_searcher, 106 | **cfg_searcher.get(searcher_type, {})) 107 | self.model.module.add_searcher(searcher, start_iter) 108 | 109 | self.search_trainer = build_trainer(cfg_stg_search, self.search_dataloader, self.model, 110 | opt, lr_scheduler, 111 | time.strftime("%Y%m%d_%H%M%S", time.localtime())) 112 | load_path = join(cfg_stg_search['save_path'], 'checkpoint', cfg_stg_search['load_name']) 113 | if cfg_stg_search.get('resume', False) and os.path.exists(load_path): 114 | self.search_trainer.load(load_path) 115 | 116 | def _build_sampler(self, cfg_sampler, cfg_data_sample, cfg_stg_sample, cfg_net, cfg_search_searcher): 117 | if has_apex: 118 | if self.local_rank == 0: 119 | logging.info('NVIDIA APEX installed. AMP on.') 120 | self.model = DDP(self.model, delay_allreduce=True) 121 | else: 122 | if self.local_rank == 0: 123 | logging.info("Using torch DistributedDataParallel. 
Install NVIDIA Apex for Apex DDP.") 124 | self.model = DDP(self.model, device_ids=[self.local_rank], find_unused_parameters=True) 125 | self.tester = build_tester(cfg_stg_sample, cfg_data_sample, self.model, cfg_search_searcher) 126 | self.sampler = build_sampler(cfg_sampler, self.tester, cfg_net, cfg_stg_sample) 127 | 128 | def search(self): 129 | self.search_trainer.train() 130 | if hasattr(self.model.module.searcher, 'get_best_arch'): 131 | self.subnet_candidates = self.model.module.searcher.get_best_arch() 132 | self.model.module.remove_searcher() 133 | 134 | def sample(self): 135 | self.sampler.sample() 136 | 137 | def statistic_flops(self): 138 | sampler = build_sampler({'type': 'random'}, self.model, None, None) 139 | logger = open('./statistic_flops.txt', 'a') 140 | for _ in range(12500): # x8 141 | subnet = sampler.generate_subnet() 142 | flops = count_flops(self.model.module.net, subnet) 143 | print('{},{}'.format(subnet, flops)) 144 | logger.write('{},{}\n'.format(subnet, flops)) 145 | logger.flush() 146 | -------------------------------------------------------------------------------- /core/dataset/augmentations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/dataset/augmentations/__init__.py -------------------------------------------------------------------------------- /core/dataset/build_dataloader.py: -------------------------------------------------------------------------------- 1 | if __name__ == '__main__': 2 | import sys 3 | sys.path.append('../../') 4 | import datasets 5 | import samplers 6 | else: 7 | import core.dataset.datasets as datasets 8 | import core.dataset.samplers as samplers 9 | import torch 10 | from torch import distributed as dist 11 | from torch.utils.data import DataLoader 12 | from .augmentations.augmentation import augmentation_cv 13 | 14 | 15 | class DataPrefetcher(object): 16 | 
class DataPrefetcher(object):
    """Wraps a DataLoader and overlaps host-to-device copies of the next batch
    with computation on the current one, using a dedicated CUDA stream."""

    def __init__(self, loader):
        self.load_init = loader
        self.loader = iter(self.load_init)
        self.stream = torch.cuda.Stream()
        # BUG FIX: preload() reads self.next_items inside its StopIteration
        # handler; on an empty loader the first call crashed with
        # AttributeError. Start from an empty batch structure instead.
        self.next_items = []
        self.preload()

    def reset_loader(self):
        """Restart iteration from the beginning of the wrapped loader."""
        self.loader = iter(self.load_init)
        self.preload()

    def preload(self):
        """Fetch the next batch and kick off its async copy to the GPU.

        On exhaustion, keeps the previous batch structure but replaces every
        entry with None so consumers can detect the end of the epoch.
        """
        try:
            self.next_items = next(self.loader)
        except StopIteration:
            self.next_items = [None for _ in self.next_items]
            return self.next_items
        except Exception as e:
            # Was a bare `except:`; chain the original cause for debuggability.
            raise RuntimeError('load data error') from e

        with torch.cuda.stream(self.stream):
            for i in range(len(self.next_items)):
                if isinstance(self.next_items[i], dict):
                    # Per-resolution image dicts: move every non-string value.
                    for idx in self.next_items[i].keys():
                        if not isinstance(self.next_items[i][idx], str):
                            self.next_items[i][idx] = self.next_items[i][idx].cuda(non_blocking=True)
                else:
                    if not isinstance(self.next_items[i][0], str):
                        self.next_items[i] = self.next_items[i].cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch and start prefetching the following one."""
        # Ensure the async copy issued by preload() has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        next_items = self.next_items
        self.preload()
        return next_items

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()
def build_dataloader(cfg_data, cfg_searcher, is_test=False, world_size=1):
    ''' Build dataloader for train and test
    For multi-source task, return a dict.
    For other task and test, return a data loader.
    '''
    # Collect every distinct input resolution across all scaling stages.
    resolution_multiplier = cfg_searcher.get('resolution_multiplier')
    max_scaling_stage = cfg_searcher.get('max_scaling_stage')
    resolution = []
    for i in range(max_scaling_stage + 1):
        for j in resolution_multiplier[i]:
            if j not in resolution:
                resolution.append(j)

    # Build one transform per resolution (dict keyed by resolution); if neither
    # rand_resize nor center_crop is configured, a single shared transform.
    transforms = {}
    transform_param = cfg_data.get('augmentation')
    resize_output_size = transform_param['resize']['output_size']
    preprocessor = transform_param.get('preprocessor', 'cv')
    for w in resolution:
        if 'rand_resize' in transform_param.keys():  # train supernet
            transform_param['rand_resize']['output_size'] = w
            transform_param['resize']['output_size'] = [w, w]
            transforms[w] = augmentation_cv(transform_param)
        elif 'center_crop' in transform_param.keys():  # sample
            # Scale the resize proportionally to the crop resolution.
            transform_param['resize']['output_size'] = int(resize_output_size / cfg_data.get('final_width') * w)
            transform_param['center_crop']['output_size'] = w
            transforms[w] = augmentation_cv(transform_param)
        else:
            transforms = augmentation_cv(transform_param)

    dataset = build_dataset(cfg_data, transforms, preprocessor)
    # BUG FIX: the None guard must run before the dataset is used — the
    # original called build_sampler(dataset, ...) first, which would raise
    # before the guard could ever take effect.
    if dataset is None:
        return None
    sampler = build_sampler(dataset, is_test)

    if 'batch_size' not in cfg_data:
        batch_size = cfg_data['imagenet']['batch_size']
    else:
        batch_size = cfg_data['batch_size']
    dataloader = DataLoader(dataset,
                            shuffle=False,  # shuffling is handled by the sampler
                            batch_size=batch_size,
                            num_workers=max(2, min(6, int(cfg_data.get('workers', 0) / world_size))),
                            sampler=sampler,
                            pin_memory=False)
    return dataloader
def img_loader(img_str, preprocessor='cv'):
    """Decode a raw image byte string with the requested backend.

    'cv' decodes via OpenCV, 'pil' via Pillow (converted to RGB);
    any other value raises ValueError.
    """
    if preprocessor == 'cv':
        raw = np.frombuffer(img_str, np.uint8)
        return cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if preprocessor == 'pil':
        with Image.open(io.BytesIO(img_str)) as decoded:
            return decoded.convert('RGB')
    raise ValueError('no such processor')
class NormalDataset(BaseMcDataset):
    """Single-task classification dataset whose images are fetched from
    memcached and whose (path, label) pairs come from a json annotation file."""

    def __init__(self, cfg, transform=None, preprocessor='cv'):
        super().__init__(preprocessor)
        self.prefix = cfg['prefix']      # path prefix prepended to every image filename
        self.transform = transform       # single transform, or dict keyed by resolution
        self.cfg = cfg
        self.parse_json_()

    def parse_json_(self):
        """Load the annotation json and build self.metas as (path, int_label) tuples."""
        # BUG FIX: json.load(open(...)) leaked the file handle; `with` closes it.
        with open(self.cfg['json_path'], 'r') as f:
            jdata = json.load(f)
        # The json has a single top-level key holding the sample list.
        self.key = list(jdata.keys())[0]
        self.num = len(jdata[self.key])

        self.metas = []
        task = self.cfg.get('task', 'classification')
        anno_type = self.cfg.get('type', 'imagenet')
        for i in range(self.num):
            path = jdata[self.key][i]['img_info']['filename']
            label = jdata[self.key][i]['annos'][task][anno_type]
            self.metas.append((path, int(label)))

    def __getitem__(self, idx):
        """Return (image, class) for sample idx; image may be a dict of
        per-resolution transformed images when self.transform is a dict."""
        filename = self.prefix + '/' + self.metas[idx][0]
        cls = self.metas[idx][1]
        # memcached: lazy client init, then fetch the raw bytes.
        self._init_memcached()
        value = mc.pyvector()
        self.mclient.Get(filename, value)
        value_str = mc.ConvertBuffer(value)
        img = img_loader(value_str, self.preprocessor)
        # transform
        if self.transform is not None:
            if isinstance(self.transform, dict):
                img = {t: tr(**{'image': img})['image'] for t, tr in self.transform.items()}
            else:
                img = self.transform(**{'image': img})['image']
        return img, cls
class DistributedSampler(Sampler):
    """Partition a dataset across replicas, shuffled deterministically per epoch.

    Each rank sees num_samples = ceil(len(dataset) / num_replicas) indices;
    indices are padded by wrapping around so every rank gets the same count.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Deterministically shuffle based on epoch so all ranks agree.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()

        # Pad by wrapping so the list divides evenly across replicas.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Strided subsample: rank r takes every num_replicas-th index.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Change the shuffle seed; call once per epoch before iterating."""
        self.epoch = epoch


class DistributedTestSampler(Sampler):
    """Test-time sampler: every rank iterates the FULL dataset in order.

    num_replicas/rank are accepted for interface parity but unused.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        self.dataset = dataset
        self.epoch = 0

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        return iter(indices)

    def __len__(self):
        # BUG FIX: returned self.num_samples, which is never defined here,
        # so len() always raised AttributeError. __iter__ yields the whole
        # dataset, so that is the correct length.
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.epoch = epoch
def get_params(self, base_lr, weight_decay):
    """Resolve the groups built by set_params() into real parameter tensors.

    Group 'lr'/'weight_decay' entries are stored as multipliers and are
    scaled here by the global base_lr / weight_decay.
    """
    if not self.param_setted:
        # Default grouping when the caller never configured one.
        self.set_params({'nn.Conv2d': [1, 2, 1, 0], 'nn.BatchNorm2d': [1, 0, 1, 0]})

    real_params = []
    for item in self.params:
        if isinstance(item['params'], str):
            # Parameters were stored by name; look the tensor up lazily so
            # grouping survives module re-wrapping.
            item['params'] = self.state_dict(keep_vars=True)[item['params']]
        if 'lr' in item:
            item['lr'] *= base_lr
        if 'weight_decay' in item:
            item['weight_decay'] *= weight_decay
        real_params.append(item)
    return real_params

def _init_params(self):
    """Initialize weights: fan-out normal for convs, constants for BN,
    N(0, 0.01) for Linear and PReLU."""
    for m in self.modules():
        if isinstance(m, nn.Conv2d) or issubclass(m.__class__, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d) or issubclass(m.__class__, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear) or issubclass(m.__class__, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.PReLU):
            m.weight.data.normal_(0, 0.01)

def reset_bn(self):
    """Reset running statistics of every BatchNorm2d (used before recalibrating
    BN for a newly sampled subnet)."""
    for m in self.modules():
        if isinstance(m, nn.BatchNorm2d) or issubclass(m.__class__, nn.BatchNorm2d):
            m.reset_running_stats()

def add_searcher(self, searcher, start_iter=0):
    """Register a searcher that becomes active from iteration `start_iter`."""
    if self.searcher is None:
        self.searcher = {}
    self.searcher[start_iter] = searcher

def remove_searcher(self):
    """Drop all registered searchers."""
    self.searcher = None

def get_loss(self, logits, label, **kwargs):
    """Compute the training loss; must be overridden by subclasses."""
    raise NotImplementedError()

def set_subnet(self, idx_list):
    """
    set a specific subnet
    :param idx_list: indexes of each choice block
    """
    assert len(self.net) == len(idx_list)
    for cb, idx in zip(self.net, idx_list):
        for b_idx, block in enumerate(cb):
            # Freeze every candidate op except the chosen one so only the
            # active path receives gradients.
            if b_idx != idx:
                for param in block.parameters():
                    param.requires_grad = False
            else:
                for param in block.parameters():
                    param.requires_grad = True
def forward(self, input, subnet=None, c_iter=None):
    """Forward the supernet through one subnet path.

    subnet: list, [op, ... , c_m, r] — per-block op indexes, then channel
    multiplier and resolution key. If omitted, taken from self.subnet (set
    externally) or generated by the searcher active at iteration c_iter.
    Returns raw logits when no labels are given, otherwise a result dict.
    """
    # subnet: list, [op, ... , c_m, r]
    if isinstance(input, dict) and 'images' in input:
        x = input['images']
    else:
        x = input

    c_searcher = None
    if self.subnet is not None and subnet is None:
        # One-shot externally injected subnet; consume it.
        subnet = self.subnet
        self.subnet = None
    elif subnet is None and self.searcher is not None:
        # search: pick the searcher with the largest start_iter < c_iter.
        if c_iter is None:
            raise RuntimeError('Param c_iter cannot be None in search mode.')
        searcher_keys = list(self.searcher.keys())
        searcher_keys.sort(reverse=True)
        for s_iter in searcher_keys:
            if s_iter < c_iter:
                c_searcher = self.searcher[s_iter]
                break
        assert c_searcher is not None
        subnet = c_searcher.generate_subnet(self)
    assert subnet is not None
    assert len(subnet) == len(self.net) + 2

    if isinstance(x, dict):
        # Per-resolution image dict: select the image matching the subnet's
        # resolution key (subnet[-1]).
        x = x[subnet[-1]]

    self.set_subnet(subnet[:-2])# use op only
    # forward
    logits = self.forward_(x, subnet)

    if isinstance(input, dict) and 'images' in input and 'labels' in input:
        accuracy = get_cls_accuracy(logits, input['labels'], topk=(1, 5))
        loss = self.get_loss(logits, input['labels'])
    elif (not isinstance(input, dict)) or (isinstance(input, dict) and 'images' not in input):
        # Raw-tensor input: plain inference, no metrics.
        return logits
    else:
        accuracy = -1
        loss = -1

    output = {'output': logits, 'accuracy': accuracy, 'loss': loss,
              'c_searcher': c_searcher, 'subnet': subnet}
    return output

def forward_(self, x, subnet):
    """Run x through the chosen op of each choice block, passing the channel
    multiplier c_m (subnet[-2]) to every block."""
    c_m = subnet[-2]
    for idx, block in zip(subnet[:-2], self.net):
        x = block[idx](x, c_m)
    return x
class BaseSampler:
    """Base class for architecture samplers: holds the tester/model handles,
    distributed rank info, and a timestamped file logger."""

    def __init__(self, tester: BaseTester, **kwargs):
        self.tester = tester
        self.model = self.tester.model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.flops_min = kwargs.pop('flops_min', 0)
        # Remaining kwargs become attributes. NOTE(review): the logger below
        # requires `cfg_stg_sample` to be among them — confirm at call sites.
        for k in kwargs:
            setattr(self, k, kwargs[k])

        # build logger
        now = time.strftime("%Y%m%d_%H%M%S", time.localtime())
        if now != '':
            now = '_' + now
        self.logger = create_logger('',
                                    '{}/log/'.format(self.cfg_stg_sample['save_path']) + '/log_sample{}.txt'.format(now))

    def forward_subnet(self, input):
        """
        run one step
        :return:
        """
        pass

    @abstractmethod
    def eval_subnet(self):
        """
        Do eval for a model family with basenet and scaling.
        :return: a score for the family
        """

    @abstractmethod
    def generate_subnet(self):
        """
        generate one subnet
        :return: block indexes for each choice block
        """

    @abstractmethod
    def sample(self):
        """
        sample basenet and scaling
        :return: None
        """
-------------------------------------------------------------------------------- /core/sampler/evolution/__pycache__/evolution_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/sampler/evolution/__pycache__/evolution_sampler.cpython-36.pyc -------------------------------------------------------------------------------- /core/sampler/evolution/__pycache__/evolution_sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/sampler/evolution/__pycache__/evolution_sampler.cpython-37.pyc -------------------------------------------------------------------------------- /core/sampler/evolution/__pycache__/nsganet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/sampler/evolution/__pycache__/nsganet.cpython-36.pyc -------------------------------------------------------------------------------- /core/sampler/evolution/__pycache__/nsganet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/sampler/evolution/__pycache__/nsganet.cpython-37.pyc -------------------------------------------------------------------------------- /core/sampler/evolution/__pycache__/sense_evolution_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/core/sampler/evolution/__pycache__/sense_evolution_sampler.cpython-36.pyc 
def __init__(self, **kwargs):
    """Set up evolution search state: per-stage weights, resume checkpoint,
    and FLOPs tolerance constants."""
    super().__init__(**kwargs)
    self.flops_constraint = float(self.flops_constraint)
    self.scaling_topk = {}
    self.basenet_topk = {}
    self.scaling_stage = self.cfg_sampler['kwargs'].get('max_scaling_stage', 3)
    self.ratio = self.cfg_sampler['kwargs'].get('ratio', 1.)
    # Geometric stage weights: highest stage gets `ratio`, each lower stage
    # is multiplied by `ratio` again.
    self.weights = {}
    self.weights[self.scaling_stage] = self.ratio
    for i in range(self.scaling_stage - 1, 0, -1):
        self.weights[i] = self.weights[i + 1] * self.ratio

    self.start_step = 'start-1'
    # Resume previous top-k results if a checkpoint yaml exists.
    self.arch_topk_path = self.cfg_stg_sample['save_path'] + '/' + self.cfg_stg_sample['resume_name']
    if os.path.exists(self.arch_topk_path):
        arch_topk = yaml.load(open(self.arch_topk_path, 'r'), Loader=yaml.FullLoader)
        # NOTE(review): eval() on strings read from the checkpoint executes
        # arbitrary code if the file is tampered with; consider
        # ast.literal_eval for these literal dict/list payloads.
        if type(arch_topk['scaling_topk']) == str:
            arch_topk['scaling_topk'] = eval(arch_topk['scaling_topk'])
        self.scaling_topk = arch_topk['scaling_topk']
        if type(arch_topk['basenet_topk']) == str:
            arch_topk['basenet_topk'] = eval(arch_topk['basenet_topk'])
        self.basenet_topk = arch_topk['basenet_topk']
        self.admm_iter = arch_topk['admm_iter']
        self.start_step = arch_topk['start_step']
    # FLOPs tolerances for basenet search and scaling search respectively.
    self.err_base = 25e6
    self.err_scale = 50e6

def generate_subnet(self):
    """Evolution search drives evaluation itself; no single subnet to emit."""
    return None
def check_basenet(self, basenet):
    """Repair a candidate basenet so it is structurally valid.

    Within each stage's base region: the first op must be non-identity, and
    ops must follow a thermometer code (all non-identity ops before any
    identity op). Positions belonging only to scaled-up depth are zeroed.
    """
    # basenet: list, [op, ...]
    depth_stage = self.model.module.depth_stage# [[base_max_depth, max_depth], ...]

    # Divide basenet into stages
    basenet_stage = []# [stage1: [op1, op2, ...], ...]
    idx = 0
    for i in depth_stage:
        basenet_stage.append(basenet[idx: (idx + i[1])])
        idx += i[1]

    # Check each stage
    for i in range(len(basenet_stage)):
        if type(basenet_stage[i]) == list and depth_stage[i][0] > 1:
            # Check the first operation (downsampling op cannot be identity)
            if basenet_stage[i][0] == 0:
                basenet_stage[i][0] = 1
            # Check the thermometer code
            id_flag = -1  # index of the first identity (0) op seen so far
            for j in range(depth_stage[i][0]):# Only consider the basenet, ignore the scaled parts
                if basenet_stage[i][j] > 0 and id_flag != -1:# Swap the non-id with id
                    basenet_stage[i][id_flag] = basenet_stage[i][j]
                    basenet_stage[i][j] = 0
                    # BUG FIX: every position in (id_flag, j) is identity, so
                    # after the swap the first identity sits at id_flag + 1.
                    # The original reset id_flag to -1 here, which let a later
                    # non-identity op remain after identities (e.g.
                    # [1,0,1,1,0] -> [1,1,0,1,0] instead of [1,1,1,0,0]),
                    # breaking the thermometer code this loop enforces.
                    id_flag += 1
                elif basenet_stage[i][j] == 0 and id_flag == -1:
                    id_flag = j

    # Gather the checked basenet
    basenet = []
    for i in basenet_stage:
        if type(i) == list:
            basenet += i
        else:
            basenet.append(i)

    # Check identity: mask out positions that exist only in scaled-up depth.
    id = []# base->1, scaled parts->0
    for i in depth_stage:
        if i[0] == i[1]:
            id.append(1)
        else:
            id += [1] * i[0] + [0] * (i[1] - i[0])
    assert len(id) == len(basenet)
    basenet = list(map(lambda x, y: x * y, basenet, id))

    return basenet
def generate_subnet_(self, basenet, scaling, scaling_stage=0):
    """Expand a (repaired) basenet with one scaling strategy into a full subnet.

    Depth scaling repeats each stage's last base op into the reserved scaled
    slots; width and resolution multipliers are appended at the end.
    """
    # Generate a subnet from a base model and a scaling strategy
    # basenet: list, [op, ...]
    # scaling: list of index, [depth, width, resolution]
    # scaling_stage: int
    basenet = self.check_basenet(basenet)
    depth_stage = self.model.module.depth_stage# [[base_max_depth, max_depth], ...]
    if scaling_stage == 0:
        # Stage 0 = unscaled basenet at the default 224 resolution.
        scaling_value = [1., 1., 224]
    else:
        scaling_value = [self.model.module.depth_multiplier[scaling_stage][scaling[scaling_stage][0]],
                         self.model.module.channel_multiplier[scaling_stage][scaling[scaling_stage][1]],
                         self.model.module.resolution_multiplier[scaling_stage][scaling[scaling_stage][2]]]

    # Divide basenet into stages
    basenet_stage = []# [stage1: [op1, op2, ...], ...]
    idx = 0
    for i in depth_stage:
        basenet_stage.append(basenet[idx: (idx + i[1])])
        idx += i[1]

    # Scaling depth
    for i in range(len(basenet_stage)):
        if type(basenet_stage[i]) == list and len(basenet_stage[i]) > 1:
            # Compute the depth of basenet in this stage (ops are a
            # thermometer code, so count until the first identity).
            n_base = 0
            for j in range(depth_stage[i][0]):# Only consider the basenet, ignore the scaled parts
                if basenet_stage[i][j] > 0:
                    n_base += 1
                else:
                    break
            # Compute the scaled depth of basenet in this stage
            n = int(math.ceil(n_base * scaling_value[0]))
            # Get the scaled stage. n <= len(basenet_stage) is guaranteed;
            # replicate the deepest base op into the added slots.
            basenet_stage[i][n_base: n] = [basenet_stage[i][n_base - 1]] * (n - n_base)

    # Get subnet
    subnet = []
    for i in basenet_stage:
        if type(i) == list:
            subnet += i
        else:
            subnet.append(i)
    return subnet + scaling_value[1:]# [op, ... , c_m, r]
def eval_subnet(self, basenet, scaling, scaling_stage, err):
    """Score a basenet/scaling family.

    If the basenet's FLOPs are within `err` of the constraint, evaluate each
    requested scaling stage (testing on held-out data when its FLOPs also
    pass check_flops, otherwise assigning a >3 penalty score) and return the
    weight-averaged score; otherwise return a penalty based on how far the
    basenet misses the constraint. NOTE(review): count_flops/check_flops are
    defined outside this view — semantics assumed from usage.
    """
    # basenet: list, [op, ...]
    # scaling: dict of list of index, [depth, width, resolution]
    # scaling_stage: list
    assert len(self.model.module.net) == len(basenet)
    flops_constraint = getattr(self, 'flops_constraint', 400e6)

    score = []
    weights_used = []
    # FLOPs of the unscaled basenet (width 1.0, resolution 224).
    flops_ = self.count_flops(basenet + [1., 224])
    if abs(flops_constraint - flops_) <= err:
        for i in scaling_stage:
            weights_used.append(self.weights[i])
            subnet = self.generate_subnet_(basenet, scaling, scaling_stage=i)
            flops = self.count_flops(subnet)
            # Each stage doubles the FLOPs budget.
            flops_constraint_ = flops_constraint * (2 ** i)
            self.logger.info('==subnet: {}, FLOPs: {}'.format(str(subnet), flops))

            if not self.check_flops([basenet], scaling, i, err):
                # Penalty score: always > any real test score.
                score.append(3. + (flops_constraint_ - flops) / flops_constraint_)
            else:
                # Recalibrate BN before testing the sampled subnet.
                self.model.module.reset_bn()
                self.model.train()
                time.sleep(2)# process may be stuck in dataloader
                score.append(self.tester.test(subnet=subnet))
                time.sleep(2)# process may be stuck in dataloader
                self.logger.info(score)
        # Weighted average over the evaluated stages.
        score = sum([s * w for s, w in zip(score, weights_used)]) / sum(weights_used)

        if self.logger is not None:
            self.logger.info('{}-{}-{}-{}-{}\n'.format(str(basenet), str(scaling),
                                                       self.scaling_stage, score, flops))
    else:
        # Basenet itself misses the FLOPs constraint: penalty only.
        score = 3. + (flops_constraint - flops_) / flops_constraint

        if self.logger is not None:
            self.logger.info('{}-{}-{}-{}-{}\n'.format(str(basenet), str(scaling),
                                                       self.scaling_stage, score, flops_))
    return score

def eval_subnet_host(self, basenet, scaling, scaling_stage, err):
    """Evaluate one basenet, or average over a list of basenets."""
    # basenet: list or list of list
    # scaling: dict
    # scaling_stage: list
    if isinstance(basenet[0], list):
        score = []
        for basenet_ in basenet:
            score.append(self.eval_subnet(basenet_, scaling, scaling_stage, err))
        score = sum(score) / len(score)
    else:
        score = self.eval_subnet(basenet, scaling, scaling_stage, err)
    return score
def sample_scaling(self, fixed_basenet=None, scaling_stage=1):
    """Search the scaling indices (depth, width, resolution) of one scaling
    stage with NSGA-Net while keeping the basenet fixed.

    fixed_basenet: list of op indices, one per supernet layer.
    scaling_stage: stage whose multiplier tables are being searched.
    Side effect: self.scaling_topk[scaling_stage] is set to the top-k
    scalings; rank 0 computes them and broadcasts so every worker agrees.
    """
    # Optimize scaling index, fix basenet
    assert len(fixed_basenet) == len(self.model.module.net)

    # arch-string -> {'acc', 'arch'}; filled as a cache by NAS._evaluate.
    scaling_eval_dict = {}

    n_offspring = None #40

    # setup NAS search problem
    n_var = 3 # depth, width, resolution
    lb = np.zeros(n_var) # left index of scaling multiplier
    ub = np.array([len(self.model.module.depth_multiplier[scaling_stage]) - 1, # right index of scaling multiplier
                   len(self.model.module.channel_multiplier[scaling_stage]) - 1,
                   len(self.model.module.resolution_multiplier[scaling_stage]) - 1], dtype=float)

    nas_problem = NAS(n_var=n_var, n_obj=1, n_constr=0, lb=lb, ub=ub,
                      eval_func=lambda scaling: self.eval_subnet_host(fixed_basenet, scaling, [scaling_stage], self.err_scale),
                      result_dict=scaling_eval_dict, rank=self.rank, world_size=self.world_size)

    # configure the nsga-net method; the initial population is seeded only
    # with scalings that already satisfy the FLOPs constraint.
    init_sampling = []
    for _ in range(self.pop_size_scaling):
        init_sampling.append(self.get_init_scaling(scaling_stage, [fixed_basenet]))
    method = engine.nsganet(pop_size=self.pop_size_scaling,
                            n_offsprings=n_offspring,
                            eliminate_duplicates=True,
                            sampling=np.array(init_sampling, dtype=np.int32))

    res = minimize(nas_problem,
                   method,
                   callback=lambda algorithm: self.generation_callback(algorithm),
                   termination=('n_gen', self.n_gens_scaling))

    # Rank 0 ranks all evaluated scalings by accuracy and broadcasts the
    # top-k; other ranks receive into a zero placeholder of matching shape.
    if self.rank == 0:
        sorted_scaling = sorted(scaling_eval_dict.items(), key=lambda i: i[1]['acc'], reverse=True)
        sorted_scaling_key = [x[1]['arch'] for x in sorted_scaling]
        scaling_topk = sorted_scaling_key[:self.sample_num_scaling]
        self.logger.info('== search result ==')
        self.logger.info([[list(x[1]['arch']), x[1]['acc']] for x in sorted_scaling])
        self.logger.info('== best scaling ==')
        self.logger.info([list(x) for x in scaling_topk])
        self.scaling_topk[scaling_stage] = scaling_topk
        scaling_topk = torch.IntTensor(scaling_topk).cuda()
        dist.broadcast(scaling_topk, 0)
    else:
        scaling_topk = torch.IntTensor([[0] * n_var for _ in range(self.sample_num_scaling)]).cuda()
        dist.broadcast(scaling_topk, 0)
    self.scaling_topk[scaling_stage] = scaling_topk.cpu().tolist()
def sample_scaling_(self, scaling_stage=1):
    """Search scaling indices for one stage against a randomly drawn, fixed
    subset of basenets (candidate scalings are scored by the average over
    that subset).

    Side effect: self.scaling_topk[scaling_stage] is set to the top-k
    scalings; rank 0 computes them and broadcasts so every worker agrees.
    """
    # Optimize scaling index for all (sampling a fixed subset of) basenets

    # arch-string -> {'acc', 'arch'}; filled as a cache by NAS._evaluate.
    scaling_eval_dict = {}
    n_offspring = None #40
    basenet_list = []
    for _ in range(self.n_basenet):
        basenet_list.append(self.get_init_basenet(flops=self.flops_constraint))


    # Every rank samples its own candidate list; rank 0's copy is then
    # broadcast so all workers evaluate the same basenets.
    basenet_list = torch.IntTensor(basenet_list).cuda()
    dist.broadcast(basenet_list, 0)
    basenet_list = basenet_list.cpu().tolist()

    # setup NAS search problem
    n_var = 3 # depth, width, resolution
    lb = np.zeros(n_var) # left index of scaling multiplier
    ub = np.array([len(self.model.module.depth_multiplier[scaling_stage]) - 1, # right index of scaling multiplier
                   len(self.model.module.channel_multiplier[scaling_stage]) - 1,
                   len(self.model.module.resolution_multiplier[scaling_stage]) - 1], dtype=float)

    nas_problem = NAS(n_var=n_var, n_obj=1, n_constr=0, lb=lb, ub=ub,
                      eval_func=lambda scaling: self.eval_subnet_host(basenet_list, scaling, [scaling_stage], self.err_scale),
                      result_dict=scaling_eval_dict, rank=self.rank, world_size=self.world_size)

    # configure the nsga-net method; seed the population with scalings that
    # already satisfy the FLOPs constraint for this basenet subset.
    init_sampling = []
    for _ in range(self.pop_size_scaling):
        init_sampling.append(self.get_init_scaling(scaling_stage, basenet_list))
    method = engine.nsganet(pop_size=self.pop_size_scaling,
                            n_offsprings=n_offspring,
                            eliminate_duplicates=True,
                            sampling=np.array(init_sampling, dtype=np.int32))

    res = minimize(nas_problem,
                   method,
                   callback=lambda algorithm: self.generation_callback(algorithm),
                   termination=('n_gen', self.n_gens_scaling))

    # Rank 0 ranks all evaluated scalings by accuracy and broadcasts the
    # top-k; other ranks receive into a zero placeholder of matching shape.
    if self.rank == 0:
        sorted_scaling = sorted(scaling_eval_dict.items(), key=lambda i: i[1]['acc'], reverse=True)
        sorted_scaling_key = [x[1]['arch'] for x in sorted_scaling]
        scaling_topk = sorted_scaling_key[:self.sample_num_scaling]
        self.logger.info('== search result ==')
        self.logger.info([[list(x[1]['arch']), x[1]['acc']] for x in sorted_scaling])
        self.logger.info('== best scaling ==')
        self.logger.info([list(x) for x in scaling_topk])
        self.scaling_topk[scaling_stage] = scaling_topk
        scaling_topk = torch.IntTensor(scaling_topk).cuda()
        dist.broadcast(scaling_topk, 0)
    else:
        scaling_topk = torch.IntTensor([[0] * n_var for _ in range(self.sample_num_scaling)]).cuda()
        dist.broadcast(scaling_topk, 0)
    self.scaling_topk[scaling_stage] = scaling_topk.cpu().tolist()
def check_flops(self, basenet_list, scaling, scaling_stage, err):
    """Check whether a scaling choice keeps the (average) scaled FLOPs close
    to the compound-scaling budget.

    basenet_list: list of basenet architectures (each a list of op indices).
    scaling: dict mapping stage -> [depth, width, resolution] indices.
    scaling_stage: int stage being checked; the budget doubles per stage.
    err: base tolerance; it is scaled by 2**scaling_stage like the budget.
    Returns True when the mean absolute deviation from the per-basenet
    budget is within the scaled tolerance; False otherwise (including for
    an empty basenet_list, which previously raised ZeroDivisionError).
    """
    if not basenet_list:
        # Nothing to check: treat as not satisfying the constraint instead
        # of dividing by zero below.
        return False

    deviations = []
    for basenet in basenet_list:
        subnet = self.generate_subnet_(basenet, scaling, scaling_stage=scaling_stage)
        # Budget: base FLOPs (width mult 1.0, 224 input) doubled per stage.
        flops_target = self.count_flops(basenet + [1., 224]) * 2 ** scaling_stage
        flops_scaled = self.count_flops(subnet)
        deviations.append(abs(flops_target - flops_scaled))

    # Tolerance grows with the stage, matching the FLOPs budget growth.
    return sum(deviations) / len(deviations) <= err * 2 ** scaling_stage
str(self.model.module.depth_multiplier)) 529 | self.logger.info('Channel multipliers: ' + str(self.model.module.channel_multiplier)) 530 | self.logger.info('Resolution multipliers: ' + str(self.model.module.resolution_multiplier)) 531 | 532 | if self.start_step.split('-')[0] == 'start': 533 | # initially sample scaling strategy with evaluating all (a fixed subset of) basenets 534 | for s in range(int(self.start_step.split('-')[1]), self.scaling_stage + 1): 535 | # self.sample_scaling_(scaling_stage=s) 536 | # print('==rank{}=={}'.format(self.rank, 1)) 537 | self.grid_search_scaling_(scaling_stage=s) 538 | if s < self.scaling_stage: 539 | self.save_checkpoint('start-' + str(s + 1), admm_iter) 540 | else: 541 | self.save_checkpoint('basenet', admm_iter) 542 | elif self.start_step.split('-')[0] == 'scaling': 543 | # load checkpoint from sampling scaling 544 | for s in range(int(self.start_step.split('-')[1]), self.scaling_stage + 1): 545 | # self.sample_scaling(fixed_basenet=[int(x) for x in self.basenet_topk[0]], scaling_stage=s) 546 | self.grid_search_scaling(fixed_basenet=[int(x) for x in self.basenet_topk[0]], scaling_stage=s) 547 | if s < self.scaling_stage: 548 | self.save_checkpoint('scaling-' + str(s + 1), admm_iter) 549 | else: 550 | self.save_checkpoint('basenet', admm_iter) 551 | 552 | # iteratively sample basenet and scaling strategy 553 | for _ in range(self.admm_iter): 554 | admm_iter -= 1 555 | scaling_top1 = {} 556 | for s in range(1, self.scaling_stage + 1): 557 | scaling_top1[s] = [int(x) for x in self.scaling_topk[s][0]] 558 | self.sample_basenet(scaling_top1, self.pop_size, self.n_gens, self.sample_num) 559 | self.save_checkpoint('scaling-1', admm_iter) 560 | 561 | for s in range(1, self.scaling_stage + 1): 562 | # self.sample_scaling(fixed_basenet=[int(x) for x in self.basenet_topk[0]], scaling_stage=s) 563 | self.grid_search_scaling(fixed_basenet=[int(x) for x in self.basenet_topk[0]], scaling_stage=s) 564 | if s < self.scaling_stage: 565 | 
def count_flops_conv(self, inp, oup, kernel_size, padding, stride, groups, input_shape):
    """Estimate multiply-accumulate FLOPs of a single conv layer.

    inp, oup: input/output channel counts.
    kernel_size, padding, stride, groups: conv hyper-parameters (square
        kernel, symmetric padding assumed).
    input_shape: [c, w, h] of the incoming feature map.
    Returns (flops, output_shape) with output_shape = [oup, w_out, h_out].
    """
    c, w, h = input_shape
    c = oup
    # Standard conv output size: floor((size + 2*pad - k) / stride) + 1.
    # The previous formula, (size + 2*pad - k + 1) // stride, is only
    # equivalent for stride == 1; for stride > 1 it undercounts by one
    # row/column whenever (size + 2*pad - k) is not a multiple of stride.
    w = (w + padding * 2 - kernel_size) // stride + 1
    h = (h + padding * 2 - kernel_size) // stride + 1
    flops = inp * oup * w * h // groups * kernel_size * kernel_size
    output_shape = [c, w, h]
    return flops, output_shape
// 2, op.stride, 618 | hid, input_shape) 619 | flops.append(flops_) 620 | #se 621 | if op.use_se: 622 | flops_ = self.count_flops_fc(hid, hid // op.se.reduction) 623 | flops.append(flops_) 624 | flops_ = self.count_flops_fc(hid // op.se.reduction, hid) 625 | flops.append(flops_) 626 | #pw 627 | flops_, input_shape = self.count_flops_conv(hid, oup, 1, 0, 1, 1, input_shape) 628 | flops.append(flops_) 629 | else: 630 | #pw 631 | flops_, input_shape = self.count_flops_conv(inp, hid, 1, 0, 1, 1, input_shape) 632 | flops.append(flops_) 633 | #dw 634 | flops_, input_shape = self.count_flops_conv(hid, hid, op.k, op.k // 2, op.stride, hid, 635 | input_shape) 636 | flops.append(flops_) 637 | #se 638 | if op.use_se: 639 | flops_ = self.count_flops_fc(hid, hid // op.se.reduction) 640 | flops.append(flops_) 641 | flops_ = self.count_flops_fc(hid // op.se.reduction, hid) 642 | flops.append(flops_) 643 | #pw 644 | flops_, input_shape = self.count_flops_conv(hid, oup, 1, 0, 1, 1, input_shape) 645 | flops.append(flops_) 646 | elif isinstance(op, Conv2d): 647 | if op.k == 3: 648 | inp = op.inp_base 649 | oup = int(math.ceil(op.oup_base * c_m)) 650 | elif op.k == 1: 651 | oup = op.oup_base 652 | inp = int(math.ceil(op.inp_base * c_m)) 653 | flops_, input_shape = self.count_flops_conv(inp, oup, op.k, op.k//2, op.stride, 1, 654 | input_shape) 655 | flops.append(flops_) 656 | elif isinstance(op, FC): 657 | flops_ = self.count_flops_fc(op.inp, op.oup) 658 | flops.append(flops_) 659 | else: 660 | flops.append(0.) 661 | 662 | return int(sum(flops)), input_shape 663 | 664 | 665 | def count_flops(self, subnet, input_size=[3, 224, 224]): 666 | # subnet: list, [op, ... 
# ---------------------------------------------------------------------------------------------------------
# Define your NAS Problem
# ---------------------------------------------------------------------------------------------------------
class NAS(Problem):
    """Single-objective pymoo Problem evaluating sampled architectures.

    Candidates are sharded round-robin across distributed ranks: rank 0
    broadcasts the candidate matrix, each rank scores its slice via
    `eval_func`, and an all_reduce merges the objective values. Scores are
    cached in `result_dict` (arch-string -> {'acc', 'arch'}) so duplicate
    candidates are not re-evaluated.
    """

    # first define the NAS problem (inherit from pymop)
    def __init__(self, n_var=20, n_obj=1, n_constr=0, lb=None, ub=None, eval_func=None, result_dict=None,
                 rank=0, world_size=1):
        # `np.int` was deprecated in NumPy 1.20 and removed in 1.24; it was
        # just an alias for the builtin int, which is the safe replacement.
        super().__init__(n_var=n_var, n_obj=n_obj, n_constr=n_constr, type_var=int)
        self.xl = lb  # per-variable lower bounds
        self.xu = ub  # per-variable upper bounds
        self._n_evaluated = 0  # keep track of how many architectures are sampled
        self.eval_func = eval_func  # callable(arch list) -> accuracy (prec@1)
        self.result_dict = result_dict  # shared evaluation cache
        self.rank = rank
        self.world_size = world_size

    def _evaluate(self, x, out, *args, **kwargs):
        # Ensure every rank evaluates the exact same candidate matrix.
        x = torch.IntTensor(x).cuda()
        dist.broadcast(x, 0)
        x = x.cpu().numpy()

        objs = torch.zeros(x.shape[0], self.n_obj)
        for i in range(self.rank, x.shape[0], self.world_size):
            # all objectives assume to be MINIMIZED !!!!!
            # NOTE: cache key must match the one used when storing below
            # (previously this lookup used replace('\n', ' ') while the store
            # used replace('\n', '') — harmless for int lists, but unified
            # here for consistency).
            arch_str = str(x[i].tolist()).replace('\n', '')
            if self.result_dict.get(arch_str) is not None:
                acc = self.result_dict[arch_str]['acc']
            else:
                acc = self.eval_func(x[i].tolist())

            logging.info('==rank{}== [{}/{}] evaluation basenet/scaling:{} prec@1:{}\n\n'.format(self.rank, i, x.shape[0],
                                                                                                arch_str, acc))
            objs[i, 0] = 100 - acc  # performance['valid_acc']
            # objs[i, 1] = 10 # performance['flops']
            self._n_evaluated += 1

        # Merge the per-rank slices (non-owning ranks contributed zeros).
        objs = objs.cuda()
        dist.all_reduce(objs)
        objs = objs.cpu().numpy().astype(np.float64)
        out["F"] = objs

        # Cache every candidate's score under its arch string.
        for i in range(x.shape[0]):
            arch_str = str(x[i].tolist()).replace('\n', '')
            if self.result_dict.get(arch_str) is None:
                self.result_dict[arch_str] = {'acc': 100 - objs[i, 0], 'arch': x[i].tolist()}
        # if your NAS problem has constraints, use the following line to set constraints
        # out["G"] = np.column_stack([g1, g2, g3, g4, g5, g6]) in case 6 constraints
class NSGANet(GeneticAlgorithm):
    """NSGA-II style genetic algorithm (based on pymoo's nsga2).

    Differs from a plain GeneticAlgorithm only in that each individual
    carries a non-domination rank and a crowding distance, and selection
    compares by domination first, crowding second.
    """

    def __init__(self, **kwargs):
        # Every individual tracks its front rank (inf = unranked) and its
        # crowding distance (-1 = not yet computed).
        kwargs['individual'] = Individual(rank=np.inf, crowding=-1)
        super().__init__(**kwargs)

        # Used by binary_tournament to pick the comparison strategy.
        self.tournament_type = 'comp_by_dom_and_crowding'
        self.func_display_attrs = disp_multi_objective
def calc_crowding_distance(F):
    """Compute the NSGA-II crowding distance for an (n_points, n_obj)
    objective matrix F.

    Boundary points of each objective get an effectively infinite distance
    (capped at 1e14); interior points get the normalized gap to their sorted
    neighbors, averaged over objectives. Runs of equal objective values
    inherit the gap of the run's boundary so duplicates are not rewarded.
    """
    BIG = 1e+14

    n_points, n_obj = F.shape

    # With at most two points everyone is a boundary point.
    if n_points <= 2:
        return np.full(n_points, BIG)

    # Per-objective sort order (stable), then the column-wise sorted values.
    order = np.argsort(F, axis=0, kind='mergesort')
    F_sorted = F[order, np.arange(n_obj)]

    # Gap between consecutive sorted values, with +/- inf sentinels so the
    # extremes end up with infinite distance.
    gaps = np.concatenate([F_sorted, np.full((1, n_obj), np.inf)]) \
        - np.concatenate([np.full((1, n_obj), -np.inf), F_sorted])

    zero_positions = np.where(gaps == 0)

    # Propagate gaps across runs of duplicates: forward for the distance to
    # the previous neighbor, backward for the distance to the next one.
    gap_prev = np.copy(gaps)
    for r, c in zip(*zero_positions):
        gap_prev[r, c] = gap_prev[r - 1, c]

    gap_next = np.copy(gaps)
    for r, c in reversed(list(zip(*zero_positions))):
        gap_next[r, c] = gap_next[r + 1, c]

    # Normalize by each objective's range; a zero range becomes NaN on
    # purpose and is zeroed out right after.
    span = np.max(F_sorted, axis=0) - np.min(F_sorted, axis=0)
    span[span == 0] = np.nan
    gap_prev, gap_next = gap_prev[:-1] / span, gap_next[1:] / span

    gap_prev[np.isnan(gap_prev)] = 0.0
    gap_next[np.isnan(gap_next)] = 0.0

    # Undo the per-objective sort and average both gaps over the objectives.
    inverse = np.argsort(order, axis=0)
    crowding = np.sum(gap_prev[inverse, np.arange(n_obj)]
                      + gap_next[inverse, np.arange(n_obj)], axis=1) / n_obj

    # Cap the boundary points' infinite distance at a large finite number.
    crowding[np.isinf(crowding)] = BIG

    return crowding
def init_model(cfg_net, cfg_search_searcher):
    """Build the super-supernet as an nn.ModuleList of per-layer candidate
    operation lists.

    cfg_net: dict of sections ('backbone' plus FC head sections); each
        backbone stage entry is [n, stride, inp, oup, t, _, ops].
    cfg_search_searcher: searcher config providing the channel/depth
        multiplier tables, the max scaling stage, and the n-laterally
        coupling factor.
    Returns (model, depth_stage) where depth_stage is a list of
    [base_max_depth, max_depth] pairs, one entry per appended layer group.
    """
    model = nn.ModuleList()
    channel_multiplier = cfg_search_searcher['channel_multiplier']
    depth_multiplier = cfg_search_searcher['depth_multiplier']
    max_scaling_stage = cfg_search_searcher['max_scaling_stage'] #B_i = B_1^i 1.2 #
    n_laterally_couplng = cfg_search_searcher['n_laterally_couplng'] #n-laterally couplng for channels

    # The supernet must be built wide and deep enough to host the largest
    # scaling stage's multipliers.
    max_channel_multiplier = max(channel_multiplier[max_scaling_stage])
    max_depth_multiplier = max(depth_multiplier[max_scaling_stage])
    depth_stage = []# [[base_max_depth, max_depth], ]
    for _type in cfg_net:
        if _type == 'backbone':
            final_pooling = cfg_net[_type].pop('final_pooling')
            for stage in cfg_net[_type]:
                n, stride, inp, oup, t, _, ops = cfg_net[_type][stage]
                inp_base = inp
                oup_base = oup
                # Stem keeps its raw input width and the head keeps its raw
                # output width; all other channels are widened to fit the
                # maximum scaled channel count.
                if stage != "conv_out":
                    oup = int(math.ceil(oup * max_channel_multiplier)) + oup_base
                if stage != "conv_stem":
                    inp = int(math.ceil(inp * max_channel_multiplier)) + inp_base

                # Broadcast a single expand rate over all ops; default to 1.
                if len(t) == 1:
                    t = t * len(ops)# expand rate
                elif len(t) == 0:
                    t = [1] * len(ops)

                # Stages between stem and head get extra (scaled) depth.
                n_init = n
                if stage != "conv_out" and stage != "conv_stem":
                    n = int(math.ceil(n * max_depth_multiplier))

                for i in range(n):
                    # Only the first layer of a stage downsamples.
                    stride = stride if i == 0 else 1
                    module_ops = nn.ModuleList()
                    for _t, op in zip(t, ops):
                        # A list-valued expand rate adds one candidate op
                        # per rate value.
                        if isinstance(_t, list):
                            for t_num in _t:
                                if stage == "conv_out":
                                    module_ops.append(OPS[op](inp, oup, t_num, stride,
                                                              n_laterally_couplng, inp_base, oup_base,
                                                              if_conv_out=True))
                                else:
                                    module_ops.append(OPS[op](inp, oup, t_num, stride,
                                                              n_laterally_couplng, inp_base, oup_base))
                        else:
                            if stage == "conv_out":
                                module_ops.append(OPS[op](inp, oup, _t, stride, n_laterally_couplng,
                                                          inp_base, oup_base, if_conv_out=True))
                            else:
                                module_ops.append(OPS[op](inp, oup, _t, stride, n_laterally_couplng,
                                                          inp_base, oup_base))
                    model.add_module(f'{_type}_{stage}_{i}', module_ops)
                    # Chain the layers: this layer's output feeds the next.
                    inp = oup
                    inp_base = oup_base
                depth_stage.append([n_init, n])
            if final_pooling:
                model.add_module(f'{_type}_final_pooling', nn.ModuleList([AveragePooling(1)]))
                depth_stage.append([1, 1])
        else:
            # Non-backbone sections are FC heads described by dicts.
            for fc_cfg in cfg_net[_type]:
                cfg = cfg_net[_type][fc_cfg]
                dim_in = cfg['dim_in']
                dim_out = cfg['dim_out']
                use_bn = cfg.get('use_bn', False)
                act = cfg.get('act', None)
                dp = cfg.get('dp', 0.)

                model.add_module('_'.join([_type, fc_cfg]),
                                 nn.ModuleList([FC(dim_in, dim_out, use_bn, dp, act)]))
                depth_stage.append([1, 1])

    print(depth_stage)
    return model, depth_stage
# ---- core/search_space/ops.py (continued) ----
# NOTE(review): the 'ir_*' InvertedResidual registrations begin just above this
# chunk; the entries below complete the op table.

# Identity / skip op: pass-through, with a 1x1 projection when shape changes.
OPS['id'] = lambda inp, oup, t, stride, n, inp_base, oup_base: Identity(inp=inp, oup=oup, t=t,
                        stride=stride, k=1, n=n, inp_base=inp_base, oup_base=oup_base)
# Plain conv-bn-act ops; the expand-ratio slot is unused here (named '_').
OPS['conv2d'] = lambda inp, oup, _, stride, n, inp_base, oup_base, if_conv_out=False: Conv2d(inp=inp,
                        oup=oup, stride=stride, k=1, n=n, inp_base=inp_base, oup_base=oup_base,
                        activation=HSwish, if_conv_out=if_conv_out)
OPS['conv3x3'] = lambda inp, oup, _, stride, n, inp_base, oup_base, if_conv_out=False: Conv2d(inp=inp,
                        oup=oup, stride=stride, k=3, n=n, inp_base=inp_base, oup_base=oup_base,
                        activation=HSwish, if_conv_out=if_conv_out)

# Candidate width multipliers used by the channel search.
channel_mults = [1.0, 0.8, 0.6, 0.4, 0.2]


# custom conv2d for channel search
class DynamicConv2d(nn.Conv2d):
    """Conv2d whose effective in/out channel widths are sliced at run time.

    `out_c` is set by the owning op before each forward and selects how many
    output filters are active; the active input width is read from the
    incoming tensor.  The layer alternates between slicing the weight tensor
    from the front and from the back (parity of `idx`) so that, across the
    `n` laterally-coupled forward passes, both ends of the full kernel
    receive gradient.
    """

    def __init__(self, *args, **kwargs):
        self.n = kwargs.pop('n', 1)   # number of laterally-coupled parts
        super().__init__(*args, **kwargs)
        self.out_c = self.out_channels  # currently active output channels (mutated by owners)
        self.idx = 0  # index of the n laterally-coupled parts

    def forward(self, input):
        in_c = input.shape[1]

        if self.idx % 2 == 0:
            # even pass: slice weights/bias from the front
            if self.groups == 1:
                w = self.weight[:self.out_c][:, :in_c].contiguous()
                b = self.bias[:self.out_c].contiguous() if self.bias is not None else None
            else:
                # grouped (depthwise) conv: output width follows the input width
                w = self.weight[:in_c].contiguous()
                b = self.bias[:in_c].contiguous() if self.bias is not None else None
        else:
            # odd pass: slice from the back
            if self.groups == 1:
                w = self.weight[(self.out_channels - self.out_c):][:, (self.in_channels - in_c):].contiguous()
                b = self.bias[(self.out_channels - self.out_c):].contiguous() if self.bias is not None else None
            else:
                w = self.weight[(self.in_channels - in_c):].contiguous()
                b = self.bias[(self.in_channels - in_c):].contiguous() if self.bias is not None else None

        if self.n > 0:
            self.idx = (self.idx + 1) % self.n
        # for grouped convs the group count must match the sliced width
        return F.conv2d(input, w, b, self.stride, self.padding, self.dilation, in_c if self.groups != 1 else 1)


class DynamicBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d that normalizes only the active channel slice.

    Mirrors DynamicConv2d's front/back alternation (parity of `idx`) so that
    the affine parameters and running statistics of both weight-tensor ends
    are exercised during supernet training.
    """

    def __init__(self, *args, **kwargs):
        self.n = kwargs.pop('n', 1)  # number of laterally-coupled parts
        super().__init__(*args, **kwargs)
        self.idx = 0  # index of the n laterally-coupled parts

    def forward(self, input):
        in_c = input.shape[1]

        self._check_input_dim(input)
        # exponential_average_factor handling copied from torch's BatchNorm:
        # use cumulative moving average when momentum is None.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.idx % 2 == 0:
            # even pass: front slice of affine params and running stats
            w = self.weight[:in_c].contiguous()
            b = self.bias[:in_c].contiguous() if self.bias is not None else None
            r_mean = self.running_mean[:in_c].contiguous()
            r_var = self.running_var[:in_c].contiguous()
        else:
            # odd pass: back slice
            w = self.weight[(self.num_features - in_c):].contiguous()
            b = self.bias[(self.num_features - in_c):].contiguous() if self.bias is not None else None
            r_mean = self.running_mean[(self.num_features - in_c):].contiguous()
            r_var = self.running_var[(self.num_features - in_c):].contiguous()

        if self.n > 0:
            self.idx = (self.idx + 1) % self.n
        return F.batch_norm(input, r_mean, r_var, w, b, self.training or not self.track_running_stats,
                            exponential_average_factor, self.eps)


class FC(nn.Module):
    """Linear head: Linear [+ BatchNorm1d] [+ activation] [+ Dropout].

    `act` is the activation class given as a string (e.g. 'nn.ReLU') and
    instantiated via eval — NOTE(review): eval on config input; fine for
    trusted configs only.
    """

    def __init__(self, dim_in, dim_out, use_bn, dp=0., act='nn.ReLU', n=1):
        super(FC, self).__init__()
        self.inp = dim_in
        self.oup = dim_out
        self.module = []
        self.module.append(nn.Linear(dim_in, dim_out))
        if use_bn:
            self.module.append(nn.BatchNorm1d(dim_out))
        if act is not None:
            self.module.append(eval(act)(inplace=True))
        if dp != 0:
            self.module.append(nn.Dropout(dp))
        self.module = nn.Sequential(*self.module)

    def forward(self, x, n_c):
        # n_c is accepted only for interface parity with the other ops.
        if x.dim() != 2:
            x = x.flatten(1)
        return self.module(x)


class BasicOp(nn.Module):
    """Base class for searchable ops; stashes all extra kwargs as attributes."""

    def __init__(self, oup, **kwargs):
        super(BasicOp, self).__init__()
        self.oup = oup
        for k in kwargs:
            setattr(self, k, kwargs[k])

    # NOTE(review): name is a typo ('channles') but is part of the public
    # interface; do not rename without updating callers.
    def get_output_channles(self):
        return self.oup


class Conv2d(BasicOp):
    """Conv-BN-activation block with dynamic channel widths.

    When `if_conv_out` is False, forward() scales the active output width by
    the channel multiplier `c_m` relative to the base model's `oup_base`.
    """

    def __init__(self, inp, oup, stride, k, inp_base=None, oup_base=None,
                 activation=nn.ReLU, n=1, **kwargs):
        self.if_conv_out = kwargs.pop('if_conv_out', False)
        super(Conv2d, self).__init__(oup, **kwargs)
        self.inp = inp
        self.oup = oup
        self.inp_base = inp_base  # inp of base model
        self.oup_base = oup_base  # oup of base model
        self.stride = stride
        self.k = k
        self.conv = DynamicConv2d(inp, oup, kernel_size=k, stride=stride, padding=k//2,
                                  bias=False, n=n)
        self.bn = DynamicBatchNorm2d(oup, n=n)
        self.act = activation()

    def forward(self, x, c_m):
        if not self.if_conv_out:
            # activate ceil(base_width * channel_multiplier) output channels
            self.conv.out_c = int(math.ceil(self.oup_base * c_m))
        x = self.conv(x)
        x = self.bn(x)
        return self.act(x)


class HSwish(nn.Module):
    """Hard-swish activation: x * relu6(x + 3) / 6."""

    def __init__(self, inplace=True):
        super(HSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=self.inplace) / 6
        return out


class HSigmoid(nn.Module):
    """Hard-sigmoid activation: relu6(x + 3) / 6."""

    def __init__(self, inplace=True):
        super(HSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        out = F.relu6(x + 3, inplace=self.inplace) / 6
        return out


class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation block on dynamic-width features.

    The two 1x1 convs are DynamicConv2d; forward() resizes their active
    output widths to match the runtime input width.
    """

    def __init__(self, in_channel,
                 reduction=4,
                 squeeze_act=nn.ReLU(inplace=True),
                 excite_act=HSigmoid(inplace=True), n=1):
        super(SqueezeExcite, self).__init__()
        self.reduction = reduction
        self.squeeze_conv = DynamicConv2d(in_channels=in_channel,
                                          out_channels=in_channel // reduction,
                                          kernel_size=1,
                                          bias=True, n=n)
        self.squeeze_act = squeeze_act
        self.excite_conv = DynamicConv2d(in_channels=in_channel // reduction,
                                         out_channels=in_channel,
                                         kernel_size=1,
                                         bias=True, n=n)
        self.excite_act = excite_act

    def forward(self, inputs):
        # track the runtime channel width
        self.squeeze_conv.out_c = inputs.size(1) // self.reduction
        self.excite_conv.out_c = inputs.size(1)
        # global average pooling done manually (mean over spatial dims)
        x = inputs.view(inputs.size(0), inputs.size(1), -1).mean(-1).view(inputs.size(0), inputs.size(1), 1, 1)
        x = self.squeeze_conv(x)
        x = self.squeeze_act(x)
        x = self.excite_conv(x)
        return inputs * self.excite_act(x)


class InvertedResidual(BasicOp):
    """MobileNetV2/V3-style inverted residual with dynamic channel widths.

    t == 1: depthwise conv -> [SE] -> pointwise-linear.
    t > 1:  pointwise expand -> depthwise -> [SE] -> pointwise-linear.
    Skip connection is used when stride == 1 and inp == oup.
    """

    def __init__(self, inp, oup, stride, t, inp_base=None, oup_base=None,
                 k=3, activation=nn.ReLU, use_se=False, n=1, **kwargs):
        super(InvertedResidual, self).__init__(oup, **kwargs)
        self.stride = stride
        self.t = t  # expand ratio
        self.k = k  # depthwise kernel size
        self.use_se = use_se
        assert stride in [1, 2]
        self.inp = inp
        self.oup = oup
        self.inp_base = inp_base  # inp of base model
        self.oup_base = oup_base  # oup of base model
        hidden_dim = int(round(inp * t))
        self.hidden_dim = hidden_dim

        if t == 1:
            # dw
            self.conv1 = DynamicConv2d(hidden_dim, hidden_dim, k, stride, padding=k//2,
                                       groups=hidden_dim, bias=False, n=n)
            self.bn1 = DynamicBatchNorm2d(hidden_dim, n=n)
            self.act = activation(inplace=True)
            # se
            self.se = SqueezeExcite(hidden_dim, n=n) if use_se else nn.Sequential()
            # pw-linear
            self.conv2 = DynamicConv2d(hidden_dim, oup, 1, 1, 0, bias=False, n=n)
            self.bn2 = DynamicBatchNorm2d(oup, n=n)
        else:
            # pw
            self.conv1 = DynamicConv2d(inp, hidden_dim, 1, 1, 0, bias=False, n=n)
            self.bn1 = DynamicBatchNorm2d(hidden_dim, n=n)
            self.act1 = activation(inplace=True)
            # dw
            self.conv1_2 = DynamicConv2d(hidden_dim, hidden_dim, k, stride, padding=k//2,
                                         groups=hidden_dim, bias=False, n=n)
            self.bn1_2 = DynamicBatchNorm2d(hidden_dim, n=n)
            self.act2 = activation(inplace=True)
            # se
            self.se = SqueezeExcite(hidden_dim, n=n) if use_se else nn.Sequential()
            # pw-linear
            self.conv2 = DynamicConv2d(hidden_dim, oup, 1, 1, 0, bias=False, n=n)
            self.bn2 = DynamicBatchNorm2d(oup, n=n)

        self.use_shortcut = inp == oup and stride == 1

    def forward(self, x, c_m):
        # scale the active expand/output widths by the channel multiplier
        self.conv1.out_c = int(math.ceil(int(round(self.inp_base * self.t)) * c_m))
        self.conv2.out_c = int(math.ceil(self.oup_base * c_m))
        if self.t == 1:
            y = self.conv1(x)
            y = self.bn1(y)
            y = self.act(y)
            y = self.se(y)
            y = self.conv2(y)
            y = self.bn2(y)
        else:
            y = self.conv1(x)
            y = self.bn1(y)
            y = self.act1(y)
            y = self.conv1_2(y)
            y = self.bn1_2(y)
            y = self.act2(y)
            y = self.se(y)
            y = self.conv2(y)
            y = self.bn2(y)
        if self.use_shortcut:
            y += x
        return y


class Identity(BasicOp):
    """Skip op: identity, or a 1x1 conv-bn projection if shape must change."""

    def __init__(self, inp, oup, stride, inp_base=None, oup_base=None, **kwargs):
        # NOTE(review): kwargs (including 'n') are also setattr'ed by BasicOp
        # before 'n' is popped here — intentional? confirm.
        super(Identity, self).__init__(oup, **kwargs)
        n = kwargs.pop('n', 1)
        self.inp = inp
        self.oup = oup
        self.inp_base = inp_base  # inp of base model
        self.oup_base = oup_base  # oup of base model
        if stride != 1 or inp != oup:
            self.downsample = True
            self.conv = DynamicConv2d(inp, oup, kernel_size=1, stride=stride, bias=False, n=n)
            self.bn = DynamicBatchNorm2d(oup, n=n)
        else:
            self.downsample = False

    def forward(self, x, c_m):
        if self.downsample:
            self.conv.out_c = int(math.ceil(self.oup_base * c_m))
            x = self.conv(x)
            x = self.bn(x)
        return x
class AveragePooling(BasicOp):
    """Global average pooling wrapped as a searchable op.

    `oup` here is the output spatial size passed to AdaptiveAvgPool2d (1 for
    global pooling); the channel multiplier is accepted in forward() only for
    interface parity with the other ops.
    """

    def __init__(self, oup, **kwargs):
        super(AveragePooling, self).__init__(oup, **kwargs)
        self.pool = nn.AdaptiveAvgPool2d(oup)

    def forward(self, x, c_m):
        return self.pool(x)


# ---- core/searcher/__init__.py ----
from .build_searcher import build_searcher


# ---- core/searcher/base_searcher.py ----
from torch import distributed as dist


class BaseSearcher:
    """Abstract search strategy; subclasses sample and rank subnets."""

    def __init__(self):
        # requires torch.distributed to be initialized before construction
        self.rank = dist.get_rank()
        self.searched = False

    def search_step(self, x, model):
        # BUGFIX: was `raise NotImplemented(...)` — NotImplemented is a
        # constant, not an exception, so raising it produced a TypeError
        # instead of the intended abstract-method error.
        raise NotImplementedError('search_step must be implemented in sub-classes')

    def get_topk_arch(self, k):
        # BUGFIX: same NotImplemented -> NotImplementedError correction.
        raise NotImplementedError('get_topk_arch must be implemented in sub-classes')


# ---- core/searcher/build_searcher.py ----
from .uniform_searcher import UniformSearcher


def build_searcher(searcher_type, cfg_search_searcher, **kwargs):
    """Factory for searchers.

    Only uniform search is implemented, so `searcher_type` is currently
    ignored; kwargs are forwarded to UniformSearcher.
    """
    return UniformSearcher(cfg_search_searcher, **kwargs)
# ---- core/searcher/uniform_searcher.py ----
# Allow running this file directly (absolute import) or as a package module.
if __name__ == '__main__':
    from base_searcher import BaseSearcher
else:
    from .base_searcher import BaseSearcher

import torch
from torch import distributed as dist
import random
import math

class UniformSearcher(BaseSearcher):
    """Uniformly samples a subnet (ops, widths, depth, resolution).

    Sampling happens on rank 0 only; the result is broadcast so every rank
    trains the same subnet for the step.
    """

    def __init__(self, cfg_search_searcher, **kwargs):
        super(UniformSearcher, self).__init__()
        # NOTE(review): BaseSearcher.__init__ already sets self.rank — redundant.
        self.rank = dist.get_rank()
        # (sic) 'constrant' typo kept: the kwarg key is part of the interface.
        # Not used inside this class — presumably consumed by callers; confirm.
        self.flops_constrant = kwargs.pop('flops_constrant', 400e6)
        self.c_ms = cfg_search_searcher['channel_multiplier']     # per scaling stage
        self.d_ms = cfg_search_searcher['depth_multiplier']       # per scaling stage
        self.rs = cfg_search_searcher['resolution_multiplier']    # per scaling stage
        self.max_scaling_stage = cfg_search_searcher['max_scaling_stage']

    def generate_subnet(self, model):
        """Return [op_idx per block] + [channel_mult] + [resolution]."""
        depth_stage = model.depth_stage# [[n_base, n_max], ]
        if self.rank == 0:
            # Pick a scaling stage with probability proportional to how many
            # channel-multiplier choices it offers.
            search_space = [len(i) for i in self.c_ms]
            rnd_s = random.randint(0, sum(search_space) - 1)
            s_s = 0
            for i in range(len(search_space)):
                if rnd_s < sum(search_space[:(i + 1)]):
                    s_s = i
                    break
            # Sample one multiplier of each kind within the chosen stage.
            c_m = self.c_ms[s_s][random.randint(0, len(self.c_ms[s_s]) - 1)]
            d_m = self.d_ms[s_s][random.randint(0, len(self.d_ms[s_s]) - 1)]
            r = self.rs[s_s][random.randint(0, len(self.rs[s_s]) - 1)]

            # One uniformly random op choice per block (index 0 reserved).
            subnet = []
            for block in model.net:
                idx = random.randint(1, len(block) - 1) if len(block) > 1 else 0
                subnet.append(idx)# only op now

            # NOTE(review): `id` shadows the builtin; kept as-is.
            id = []# base->1, scaled->0
            subnet_scaled = []# scaled->op, other->0
            for i in depth_stage:
                if i[0] == i[1]:
                    # fixed-depth stage (e.g. stem / head / pooling)
                    id.append(1)
                    subnet_scaled.append(0)
                else:
                    # sample the base depth, then scale it by d_m; the scaled
                    # extra layers repeat the last base layer's op choice
                    n_base = random.randint(1, i[0])
                    n_op = int(math.ceil(n_base * d_m))
                    subnet_scaled += [0] * n_base + [subnet[len(id) + n_base - 1]] * (n_op - n_base) \
                                     + [0] * (i[1] - n_op)
                    id += [1] * n_base + [0] * (i[1] - n_base)
            assert len(id) == len(subnet)
            # zero out ops for layers beyond the base depth ...
            subnet = list(map(lambda x, y: x * y, subnet, id))
            assert len(subnet_scaled) == len(subnet)
            # ... then re-enable the depth-scaled layers with their op choice
            subnet = list(map(lambda x, y: x + y, subnet, subnet_scaled))

            subnet = torch.IntTensor(subnet).cuda()
            w = torch.Tensor([c_m]).cuda()
            r = torch.IntTensor([r]).cuda()
        else:
            # placeholders to receive the broadcast
            subnet = torch.zeros(len(model.net), dtype=torch.int32).cuda()
            w = torch.zeros(1).cuda()
            r = torch.zeros(1, dtype=torch.int32).cuda()

        dist.broadcast(subnet, 0)
        dist.broadcast(w, 0)
        dist.broadcast(r, 0)
        subnet = subnet.cpu().tolist()
        w = w.cpu().tolist()
        r = r.cpu().tolist()
        return subnet + w + r# [op, ... , c_m, r]


# ---- core/utils/arch_util.py ----
from core.search_space.ops import channel_mults

def _decode_arch(self, subnet_topk):
    """Decode encoded subnets back into net-config dicts.

    NOTE(review): defined at module level but takes `self` — presumably bound
    to (or called from) an agent class that has `net_cfg`; confirm at call site.
    Each encoded subnet is [op indices | channel-mult indices] (two halves).
    """
    canidates = []  # (sic)
    for target_subnet in subnet_topk:
        new_arch = {}
        count = 0
        off_set = len(target_subnet) // 2  # second half holds width choices
        for _type in self.net_cfg:
            if _type == 'backbone':
                new_arch = {_type: {}}
                for stage in self.net_cfg[_type]:
                    if isinstance(self.net_cfg[_type][stage], list):
                        n, stride, inp, ori_oup, t, c_search, ops = self.net_cfg[_type][stage]
                        for i in range(n):
                            # stride applies to the first layer of a stage only
                            stride = stride if i == 0 else 1
                            oup = int(ori_oup * channel_mults[target_subnet[count + off_set]])
                            new_arch[_type]['{}_{}'.format(stage, count)] = [1, stride, inp,
                                                                             oup, t,
                                                                             False, ops[target_subnet[count]]]
                            inp = oup
                            count += 1
                    else:
                        new_arch[_type][stage] = self.net_cfg[_type][stage]
            else:
                new_arch[_type] = self.net_cfg[_type]
        canidates.append(new_arch)
    return canidates


# ---- core/utils/flops.py ----
import torch
import torch.nn as nn
import pickle
import yaml
from core.search_space.ops import InvertedResidual, FC, Conv2d, SqueezeExcite


def count_flops(model, subnet=None, input_shape=[3, 224, 224]):
    """Analytic multiply-accumulate count for the chosen subnet.

    `model` is an iterable of nn.ModuleList op choices; `subnet` gives the
    chosen op index per block (defaults to index 0 everywhere).
    """
    if subnet is None:
        subnet = [0] * len(model)
    flops = []
    m_list = []
    skip = 0
    for ms, idx in zip(model, subnet):
        for m in ms[idx].modules():
            if isinstance(m, SqueezeExcite):
                # The next two convs belong to the SE block: they operate on a
                # 1x1 pooled map, so only in_c * out_c is counted for them.
                skip = 2
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                if skip == 0:
                    m_list.append(m)
                else:
                    flops.append(m.in_channels * m.out_channels)
                    skip -= 1

    # Walk the collected layers, tracking the spatial size as it shrinks.
    c, w, h = input_shape
    for m in m_list:
        if isinstance(m, nn.Conv2d):
            c = m.out_channels
            w = (w + m.padding[0] * 2 - m.kernel_size[0] + 1) // m.stride[0]
            h = (h + m.padding[1] * 2 - m.kernel_size[1] + 1) // m.stride[1]
            flops.append(m.in_channels * m.out_channels * w * h // m.groups * m.kernel_size[0] * m.kernel_size[1])
        elif isinstance(m, nn.Linear):
            flops.append(m.in_features * m.out_features)
    return sum(flops)
# ---- core/utils/logger.py ----
import logging

def create_logger(name, log_file, level=logging.INFO):
    """Create (or fetch) a named logger that appends to ``log_file``.

    Args:
        name: logger name passed to logging.getLogger.
        log_file: path the FileHandler writes to.
        level: logging level (default INFO).

    BUGFIX: the original unconditionally appended a new FileHandler on every
    call, so calling it twice with the same name duplicated every log record.
    A handler is now attached only when the logger has none yet.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if not logger.handlers:  # avoid duplicate handlers on repeated calls
        formatter = logging.Formatter('[%(asctime)s][%(filename)15s]'
                                      '[line:%(lineno)4d][%(levelname)8s]%(message)s')
        fh = logging.FileHandler(log_file)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return logger
# ---- core/utils/lr_scheduler.py ----
import torch
import copy
import numpy as np


class IterLRScheduler(object):
    '''Iteration-based LR scheduler with warmup and step/cosine decay.

    Args:
        optimizer: torch optimizer whose param_groups are updated in place.
        lr_steps / lr_mults: milestone iterations and the multiplier applied
            at each milestone (step decay only).
        warmup_steps / warmup_strategy / warmup_lr: linear ('gradual') warmup.
        latest_iter: last finished iteration (-1 means fresh start).
        decay_stg: 'step' or 'cosine'.
        decay_step / alpha: cosine horizon (iterations) and floor fraction.
    '''

    def __init__(self, optimizer, lr_steps, lr_mults, warmup_steps,
                 warmup_strategy, warmup_lr, latest_iter=-1, decay_stg='step', decay_step=500000, alpha=0.):
        milestones = lr_steps
        assert len(milestones) == len(lr_mults), "{} vs {}".format(milestones, lr_mults)
        assert decay_stg in ['step', 'cosine']
        self.decay_stg = decay_stg
        self.milestones = milestones
        self.lr_mults = lr_mults
        self.warmup_steps = warmup_steps
        self.warmup_lr = warmup_lr
        self.warmup_strategy = warmup_strategy
        if self.decay_stg == 'cosine':
            self.alpha = alpha
            self.decay_step = decay_step
            assert self.decay_step > self.warmup_steps
        # For 'step' decay, set_lr flips to True once the resume catch-up
        # pass (replaying all already-passed milestones) has run.
        self.set_lr = False if self.decay_stg == 'step' else True
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        # Snapshot of the initial groups: decay is computed from these base LRs.
        self.ori_param_groups = copy.deepcopy(optimizer.param_groups)
        for i, group in enumerate(optimizer.param_groups):
            if 'lr' not in group:
                raise KeyError("param 'lr' is not specified "
                               "in param_groups[{}] when resuming an optimizer".format(i))
        self.latest_iter = latest_iter

    def _get_lr(self):
        """Compute the LR for every param group at self.latest_iter."""
        if self.latest_iter < self.warmup_steps:
            if self.warmup_strategy == 'gradual':
                # linear ramp from 0 to the base lr over warmup_steps
                return list(map(lambda group: group['lr'] * float(self.latest_iter) / \
                        float(self.warmup_steps), self.ori_param_groups))
            # NOTE(review): any other warmup strategy falls through to the
            # decay logic below — presumably intended; confirm.
        elif self.decay_stg == 'cosine':
            c_step = min(self.latest_iter, self.decay_step)
            # cosine from 1 down to alpha over [warmup_steps, decay_step]
            decayed = (1 - self.alpha) * 0.5 * (1 + np.cos(
                np.pi * (c_step - self.warmup_steps) / (self.decay_step - self.warmup_steps))) + self.alpha
            return list(map(lambda group: group['lr'] * decayed, self.ori_param_groups))

        if not self.set_lr:
            # First call after init/resume: replay all milestones passed so far.
            mults = 1.
            for milestone, mult in zip(self.milestones, self.lr_mults):
                if milestone <= self.latest_iter:
                    mults *= mult
                else:
                    break
            self.set_lr = True
            return list(map(lambda group: group['lr'] * mults, self.ori_param_groups))
        else:
            # Steady state: change the lr only exactly at a milestone.
            # CLEANUP: removed a dead bare `except: raise Exception('wtf?')`
            # branch — list.index can only raise ValueError here.
            try:
                pos = self.milestones.index(self.latest_iter)
            except ValueError:
                return list(map(lambda group: group['lr'], self.optimizer.param_groups))
            return list(map(lambda group: group['lr'] * self.lr_mults[pos], self.optimizer.param_groups))

    def get_lr(self):
        """Current LR of each param group."""
        return list(map(lambda group: group['lr'], self.optimizer.param_groups))

    def step(self, this_iter=None):
        """Advance to `this_iter` (default: latest+1) and write the LRs back."""
        if this_iter is None:
            this_iter = self.latest_iter + 1
        self.latest_iter = this_iter
        for param_group, lr in zip(self.optimizer.param_groups, self._get_lr()):
            param_group['lr'] = lr


# ---- core/utils/measure.py ----
from torch.autograd import Variable  # file-level import kept (unused by the hooks)
import torch.nn as nn
import time
import sys
try:
    # Used only by the __main__ benchmark script at the bottom of measure.py.
    import yaml
    from core.model.net import Net
except ImportError:
    # NOTE(review): keeps the hook helpers importable without PyYAML / the
    # project package; the __main__ script still fails loudly if they are
    # missing, via NameError.
    pass


def count_fc_flops_params(m, x, y):
    """Forward hook for nn.Linear: store FLOPs and param counts on `m`.

    FLOPs = 2 * in_features * out_features (multiply + add), minus one add
    per output feature when the layer has no bias.

    BUGFIX: the original executed `ret -= m.bias.size(0)` inside the
    `m.bias is None` branch, which always raised AttributeError for layers
    built with bias=False; subtract out_features (m.weight.size(0)) instead.
    """
    ret = 2 * m.weight.numel()
    n_ps = m.weight.numel()
    if m.bias is None:
        ret -= m.weight.size(0)
    else:
        n_ps += m.bias.size(0)
    m.flops = torch.Tensor([ret])
    m.n_params = torch.Tensor([n_ps])

def count_conv_flops_params(m, x, y):
    """Forward hook for nn.Conv2d: FLOPs from output size and kernel shape.

    BUGFIX: Conv2d's weight is shaped (out_channels, in_channels/groups,
    kH, kW), so `c_in` unpacked from the weight is ALREADY per-group; the
    original's extra `/ m.groups` under-counted grouped/depthwise convs by a
    factor of `groups`. The division is removed.
    NOTE(review): for ConvTranspose2d (also hooked by measure_model) the
    first two weight dims are swapped — counts for transposed convs are
    approximate; confirm if that matters.
    """
    c_out, c_in, ks_h, ks_w = m.weight.size()
    out_h, out_w = y.size()[-2:]
    n_ps = m.weight.numel()
    if m.bias is None:
        # multiply-adds minus the missing bias add per output element
        ret = (2 * c_in * ks_h * ks_w - 1) * out_h * out_w * c_out
    else:
        ret = 2 * c_in * ks_h * ks_w * out_h * out_w * c_out
        n_ps += m.bias.size(0)
    m.flops = torch.Tensor([ret])
    m.n_params = torch.Tensor([n_ps])

def count_bn_params(m, x, y):
    """Forward hook: parameter count (weight + bias) of a norm layer."""
    n_ps = 0
    if m.weight is not None:
        n_ps += m.weight.numel()
    if m.bias is not None:
        n_ps += m.bias.numel()
    m.n_params = torch.Tensor([n_ps])

def flops_str(FLOPs):
    """Render a count with a T/G/M/K suffix, e.g. 2e6 -> '2.000M'."""
    preset = [(1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')]

    for p in preset:
        if FLOPs // p[0] > 0:
            N = FLOPs / p[0]
            ret = "%.3f%s" % (N, p[1])
            return ret
    ret = "%.1f" % (FLOPs)
    return ret
def measure_model(model, input_shape=[3, 224, 224]):
    """Instrument `model` with FLOPs/param hooks, run it once on fake data,
    then time 10 additional forward passes.

    Returns:
        (flops_string, params_string, avg_seconds_per_forward)

    NOTE(review): `input_shape` is a mutable default (shared across calls) and
    hooks/buffers stay attached to `model` afterwards — acceptable for a
    one-shot measurement script, but do not reuse the model for training.
    """
    for m in model.modules():
        # only instrument leaf-ish modules (at most one child)
        if len(list(m.children())) > 1:
            continue
        if isinstance(m, nn.Linear):
            m.register_forward_hook(count_fc_flops_params)
            m.register_buffer('flops', torch.zeros(0))
            m.register_buffer('n_params', torch.zeros(0))
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            m.register_forward_hook(count_conv_flops_params)
            m.register_buffer('flops', torch.zeros(0))
            m.register_buffer('n_params', torch.zeros(0))
    # the project Net consumes a dict with an 'images' entry — confirm schema
    fake_data = {'images': torch.randn([1] + input_shape)}
    if torch.cuda.is_available():
        fake_data['images'] = fake_data['images'].cuda()
    model(fake_data)  # one pass to trigger the counting hooks
    total_flops = 0.
    total_params = 0.
    for m in model.modules():
        if hasattr(m, 'flops'):
            total_flops += m.flops.item()
        if hasattr(m, 'n_params'):
            total_params += m.n_params.item()
    # rough timing: average wall-clock over 10 forward passes
    s_t = time.time()
    for i in range(10):
        model(fake_data)
    avg_time = (time.time() - s_t) /10
    return flops_str(total_flops), flops_str(total_params), avg_time

if __name__ == '__main__':
    # usage: python measure.py <config.yaml> — measures the 'test' model config
    cfg = sys.argv[1]
    config = yaml.load(open(sys.argv[1], 'r')).pop('test')
    model_cfg = config['model']
    model = Net(model_cfg)
    data_cfg = config.pop('data')
    input_shape = [data_cfg['final_channel'], data_cfg['final_height'], data_cfg['final_width']]
    print(measure_model(model, input_shape=input_shape))


# ---- core/utils/misc.py ----
import numpy as np
from functools import reduce

import torch

class AverageMeter(object):
    ''' Computes and stores the average and current value.

    With length > 0 the average is over a sliding window of the last
    `length` updates; otherwise it is the running mean of all updates.
    '''
    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        self.history = []
        self.count = 0
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.all = 0.0

    def update(self, val):
        self.val = val
        self.sum += val
        self.count += 1
        # running total in hours — assumes `val` is seconds; TODO confirm
        self.all = self.sum / 3600
        if self.length > 0:
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]
            self.avg = np.mean(self.history)
        else:
            self.avg = self.sum / self.count

def get_cls_accuracy(output, target, topk=(1, ), ignore_indices=[-1]):
    """Computes the precision@k for the specified values of k.

    Samples whose target is in `ignore_indices` are masked out first.
    NOTE(review): the everything-masked branch returns a CUDA tensor, so that
    path requires a GPU — presumably only reached in distributed training;
    confirm.
    """
    target = target.long()
    # keep only samples whose target is not an ignore index
    masks = [target != idx for idx in ignore_indices]
    mask = reduce(lambda x, y : x&y, masks)
    keep = torch.nonzero(mask).squeeze()
    if keep.numel() <= 0:
        return [torch.cuda.FloatTensor([1]).zero_()]
    if keep.dim() == 0:
        keep = keep.view(-1)
    assert keep.dim() == 1, keep.dim()
    target = target[keep]
    output = output[keep]
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row k holds the k-th guess per sample
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # a sample counts if any of its top-k guesses is correct
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


# ---- core/utils/optimizer.py ----
import torch


def build_optimizer(model, cfg_optimizer):
    ''' Build optimizer for training.

    Pops 'type' (default 'SGD') from the config; every remaining key is
    passed straight to the torch.optim constructor.  Base lr/weight decay
    are also forwarded to model.get_params so it can build per-parameter
    groups (e.g. no decay on norms/biases).
    '''
    opt_type = cfg_optimizer.pop('type', 'SGD')
    base_lr = cfg_optimizer.get('lr', 0.01)
    base_wd = cfg_optimizer.get('weight_decay', 0.0001)
    parameters = model.get_params(base_lr, base_wd)
    opt_fun = getattr(torch.optim, opt_type)
    optimizer = opt_fun(parameters, **cfg_optimizer)
    return optimizer
'''gen_subnet.py -- expand a supernet config into per-layer single-op
candidate lists, then splice one selected subnet encoding into a retrain
config file.'''
import sys

import yaml

path = sys.argv[1]
# FullLoader: explicit loader is required by modern PyYAML and matches
# the style used in tools/agent_run.py; `with` ensures the file is closed.
with open(path, 'r') as f:
    cfg = yaml.load(f, Loader=yaml.FullLoader)

model = cfg['model']
backbone = model['backbone']

# Flatten the searchable backbone into `supernet`: one entry per layer,
# each entry a list of single-op candidate blocks shaped like
# [n, stride, c_in, c_out, [expand_ratio], c_search, [op]].
supernet = []
for stage_name in backbone:
    item = backbone[stage_name]
    print(stage_name, item)
    if not isinstance(item, list):
        # non-list entries are not searchable stages
        continue
    n, stride, c_in, c_out, expand_ratio, c_search, ops = item
    if len(expand_ratio) == 0:
        expand_ratio = [1]
    if not isinstance(expand_ratio[0], list):
        # a single shared ratio list applies to every op
        expand_ratio = [expand_ratio] * len(ops)
    for _ in range(n):
        # renamed from `stage`: the original shadowed the loop variable
        layer = []
        for ts, op in zip(expand_ratio, ops):
            for t in ts:
                layer.append([1, stride, c_in, c_out, [t], c_search, [op]])
        supernet.append(layer)
        c_in = c_out
        stride = 1  # only the first repeat of a stage keeps the stage stride

for idx, layer in enumerate(supernet):
    print(f'\nstage {idx}')
    for i, op in enumerate(layer):
        print(f'{i} {op}')

# Selected subnet encoding: one op index per layer.  The long history of
# alternative candidate strings that used to be dead-reassigned here
# (only the last took effect) lives in version control; only the final,
# effective choice is kept.
s = '0 0 4 0 6 0 3 4 0 6 2 4 1 6 2 4 1 0 5 4 2 6 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0'  # 0 49.66653224400112 319698464
n = 1  # index of the output config file

# str.split() with no argument is robust to repeated whitespace; the
# original replace('  ', ' ').split(' ') crashed on runs of 3+ spaces.
subnet = [int(i) for i in s.split()]
res_subnet = []
for idx, ops in zip(subnet, supernet):
    res_subnet.append(ops[idx])

print('\ngenerate result:')
print(f'subnet: {subnet}\n')
res = []
for idx, op in enumerate(res_subnet):
    # NOTE(review): the leading whitespace of these lines must match the
    # YAML template's indentation -- confirm against the template file.
    print(f'    layer{idx}: {op}')
    res.append(f'    layer{idx}: {op}')


ori_yaml = 'config/retrain/greedynas_se_330/config1.yaml'
new_yaml = 'config/retrain/megvii_7w_correlation/config{}.yaml'.format(n)
# line span in the template that holds the layer list to be replaced
start_l, end_l = 5, 28

with open(ori_yaml, 'r') as f:
    data = f.read().split('\n')

# splice the generated layer lines over the template's placeholder span
data = data[:start_l] + res + data[end_l + 1:]

with open(new_yaml, 'w') as f:
    f.write('\n'.join(data))
# get_best_subnet.py -- parse a search log and report the best subnets
# by score, plus the best within a target FLOPs band.
import sys


def parse_log(lines):
    """Parse log lines shaped like '<arch tokens>] - <score> - <flops>'.

    Lines without ']' are treated as continuations and concatenated onto
    the next line (long arch lists wrap in the log).  Lines that do not
    split into exactly three '-'-separated fields are skipped.

    Returns:
        List of [arch_string, score (float), flops (int)] entries.
    """
    nets = []
    pending = ''
    for raw in lines:
        # Fixed: the original used line[:-1], which ate the last character
        # of a final line that lacked a trailing newline.
        line = raw.rstrip('\n')
        if ']' not in line:
            pending += line
            continue
        line = pending + line
        pending = ''
        print(line)
        parts = line.split('-')
        if len(parts) != 3:
            continue
        # split() with no argument collapses repeated whitespace
        arch = ' '.join(parts[0].split())
        nets.append([arch, float(parts[1]), int(parts[2])])
    return nets


if __name__ == '__main__':
    log_path = sys.argv[1]
    with open(log_path, 'r') as log:  # fixed: log file was never closed
        nets = parse_log(log)

    nets.sort(key=lambda x: x[1], reverse=True)
    print('===================result========================')
    print('\n'.join([str(x) for x in nets]))
    print('===================top 10========================')
    print('\n'.join([str(x) for x in nets[:10]]))
    print('==================generate=======================')
    # NOTE(review): the two generate sections slice the arch string with
    # [2:-1] vs [1:-1] -- this looks inconsistent; confirm which offset
    # matches the log format before relying on the emitted `s = ...` lines.
    print('\n'.join(['s = \'' + x[0][2:-1] + '\' # '
                     + '{} {}'.format(x[1], x[2]) for x in nets][:10]))
    print('================FLOPs top 10=====================')
    print('\n'.join([str(x) for x in nets if 318e6 < x[2] < 322e6][:10]))
    print('==================generate=======================')
    print('\n'.join(['s = \'' + x[0][1:-1] + '\' # '
                     + '{} {}'.format(x[1], x[2]) for x in nets
                     if 318e6 < x[2] < 322e6][:10]))
#!/bin/bash 2 | ./sample/run.sh Train training 8 4 "./config/training_super_supernet.yaml" 3 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/__init__.py -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tools/__pycache__/dist_init.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/__pycache__/dist_init.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/dist_init.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/__pycache__/dist_init.cpython-37.pyc -------------------------------------------------------------------------------- /tools/agent_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 
import torch
import os

def dist_init(port=23456):
    """Initialize torch.distributed (NCCL backend) from SLURM environment variables.

    Reads SLURM_PROCID / SLURM_LOCALID / SLURM_NTASKS / SLURM_STEP_NODELIST to
    derive the process rank, local rank, world size and master address, then
    initializes the process group and pins this process to its local GPU.

    Args:
        port (int): TCP port of the rendezvous master. Default 23456.

    Returns:
        tuple: (rank, local_rank, world_size) for the calling process.
    """

    def init_parrots(host_addr, rank, local_rank, world_size, port):
        # Parrots-flavoured torch reads rendezvous info from env vars
        # instead of an explicit init_method URL.
        os.environ['MASTER_ADDR'] = str(host_addr)
        os.environ['MASTER_PORT'] = str(port)
        os.environ['WORLD_SIZE'] = str(world_size)
        os.environ['RANK'] = str(rank)
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    def init(host_addr, rank, local_rank, world_size, port):
        # Stock PyTorch: rendezvous via an explicit tcp:// init method.
        host_addr_full = 'tcp://' + host_addr + ':' + str(port)
        torch.distributed.init_process_group("nccl", init_method=host_addr_full,
                                             rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
        assert torch.distributed.is_initialized()


    def parse_host_addr(s):
        # SLURM nodelists may be compressed, e.g. "prefix[3-5,7]"; pick the
        # first node of the first range as the master address.
        if '[' in s:
            left_bracket = s.index('[')
            right_bracket = s.index(']')
            prefix = s[:left_bracket]
            first_number = s[left_bracket+1:right_bracket].split(',')[0].split('-')[0]
            return prefix + first_number
        else:
            return s

    rank = int(os.environ['SLURM_PROCID'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    world_size = int(os.environ['SLURM_NTASKS'])

    ip = parse_host_addr(os.environ['SLURM_STEP_NODELIST'])

    # Parrots identifies itself through torch.__version__.
    if torch.__version__ == 'parrots':
        init_parrots(ip, rank, local_rank, world_size, port)
    else:
        init(ip, rank, local_rank, world_size, port)

    return rank, local_rank, world_size

from collections import defaultdict
import os
import torch

class Info():
    """Accumulates per-task metric values and tracks the best running average.

    Used by testers to decide when to snapshot the best checkpoint.
    """

    def __init__(self):
        super().__init__()
        # Best mean-over-tasks metric seen so far; -1 means "nothing recorded yet".
        self.best_avg = -1
        # task key -> list of recorded metric values
        self._info = defaultdict(list)

    def add(self, task_names, value):
        # Record one metric sample under the given task key.
        self._info[task_names].append(value)

    def avg_(self, prefix, ckpt):
        """Average recorded values per task; if the mean over tasks beats
        best_avg, stamp it into `ckpt` and save it as best_ckpt.pth.tar
        under `prefix`. Recorded values are cleared afterwards either way."""
        total_task = 0
        total_avg = 0
        for k in self._info:
            avg = sum(self._info[k]) / len(self._info[k])
            total_task += 1
            total_avg += avg
        if total_task != 0:
            if total_avg / total_task > self.best_avg:
                self.best_avg = total_avg / total_task
                ckpt['best_avg'] = self.best_avg
                torch.save(ckpt, os.path.join(prefix, 'best_ckpt.pth.tar'))
        self._info = defaultdict(list)

class BaseTester(object):
    ''' Base Tester class: holds model/checkpoint/dataloader wiring.

    Subclasses implement the actual test / load / dataloader logic; every
    stub below raises to force the override.
    '''

    def __init__(self, cfg_data, model, model_folder, model_name,
                 dataloader_fun, cfg_searcher):
        self.cfg_data = cfg_data
        self.model = model
        # Also initializes self.model_loaded = False.
        self.set_model_path(model_name, model_folder)
        self.dataloader_fun = dataloader_fun
        self.cfg_searcher = cfg_searcher
        self.best_loaded = False
        self.info = Info()

    def test(self, subnet):
        raise RuntimeError('BaseTester cannot test')

    def load(self):
        raise RuntimeError('BaseTester cannot load search_space')

    def gen_dataloader(self):
        raise RuntimeError('BaseTester cannot generate dataloader')

    def set_model_path(self, model_name, model_folder=None):
        # Point the tester at a (possibly new) checkpoint; resets the
        # loaded flag so the next test() reloads weights.
        self.model_name = model_name
        if model_folder is not None:
            self.model_folder = model_folder
        self.model_loaded = False

    def eval_init(self):
        raise RuntimeError('BaseTester cannot init evaluation')

    def save_eval_result(self):
        raise RuntimeError('BaseTester cannot save evaluation result')

    def predict_single_img(self):
        raise RuntimeError('BaseTester cannot predict single image')
from os.path import join, exists
from core.utils.misc import get_cls_accuracy

import torch
from torch import distributed as dist
from core.utils.misc import AverageMeter
import core.dataset.build_dataloader as BD
from tools.eval.base_tester import BaseTester


class ImagenetTester(BaseTester):
    ''' Multi-Source Tester: test multi dataset one by one
    requires attrs:
        - in Base Tester
            (load) model_folder, model_name
            (config) cfg_data[with all datasets neet to be tested], cfg_stg[build dataloader]

        - in Customized Tester
            (dist) rank, world_size
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        # Promote every kwarg (rank, world_size, ...) to an attribute.
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.required_atts = ('rank', 'world_size')
        for att in self.required_atts:
            if not hasattr(self, att):
                raise RuntimeError(f'ImagenetTester must has attr: {att}')
        self.dataloader = None
        # Eagerly load checkpoint weights if a model_name was configured.
        if not self.model_loaded:
            self.load()
            self.model_loaded = True

    def test(self, subnet=None):
        # subnet: list, [op, ... , c_m, r]
        top1 = AverageMeter()
        top5 = AverageMeter()
        ''' test setted model in self '''
        if not self.model_loaded:
            self.load()
            self.model_loaded = True
        # Lazily build the prefetching dataloader on first use.
        # NOTE(review): "== None" — idiomatic form is "is None".
        if self.dataloader == None:
            self.dataloader = BD.DataPrefetcher(self.gen_dataloader())
        dataloader = iter(self.dataloader)
        input, target = next(dataloader)
        # Prefetcher signals exhaustion by yielding None; reset and retry once.
        if input is None:
            self.dataloader.reset_loader()
            input, target = next(dataloader)
        with torch.no_grad():
            while input is not None:
                logits = self.model(input, subnet=subnet)
                # self.model.module.n > 1: accumulate logits over repeated
                # forward passes (presumably stochastic paths — verify against Net).
                if self.model.module.n > 1:
                    for _ in range(1, self.model.module.n):
                        logits += self.model(input, subnet=subnet)
                prec1, prec5 = get_cls_accuracy(logits, target, topk=(1, 5))
                top1.update(prec1.item())
                top5.update(prec5.item())

                input, target = next(dataloader)

        print('==[rank{rank}]== Prec@1 {acc1.avg:.3f} Prec@5 {acc5.avg:.3f}'.format(rank=self.rank,
                                                                                    acc1=top1, acc5=top5))
        self.save_eval_result(top1.avg, top5.avg)
        return top1.avg

    def gen_dataloader(self):
        # Build the evaluation dataloader from the configured factory.
        return self.dataloader_fun(self.cfg_data, self.cfg_searcher, is_test=True, world_size=self.world_size)

    def reduce_tensor(self, tensor):
        # Average a tensor across all ranks (non-destructive: works on a clone).
        # NOTE(review): dist.reduce_op is deprecated in modern PyTorch in
        # favour of dist.ReduceOp — confirm the torch version in use.
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.reduce_op.SUM)
        rt /= self.world_size
        return rt

    def load(self):
        """Load checkpoint weights from model_folder/model_name onto the GPU,
        renaming legacy 'classification_head' keys to 'head'."""
        if self.model_loaded or self.model_name is None:
            return
        # load state_dict
        ckpt_path = join(self.model_folder, self.model_name)
        assert exists(ckpt_path), f'{ckpt_path} not exist.'
        if self.rank == 0:
            print(f'==[rank{self.rank}]==loading checkpoint from {ckpt_path}')

        def map_func(storage, location):
            # Deserialize all tensors directly onto the current GPU.
            return storage.cuda()

        ckpt = torch.load(ckpt_path, map_location=map_func)
        from collections import OrderedDict
        fixed_ckpt = OrderedDict()
        # Migrate old checkpoints: 'classification_head.*' -> 'head.*'.
        for k in ckpt['state_dict']:
            if 'head' in k:
                k1 = k.replace('classification_head', 'head')
                fixed_ckpt[k1] = ckpt['state_dict'][k]
                continue
            fixed_ckpt[k] = ckpt['state_dict'][k]
        ckpt['state_dict'] = fixed_ckpt
        self.model.load_state_dict(ckpt['state_dict'], strict=False)
        ckpt_keys = set(ckpt['state_dict'].keys())
        own_keys = set(self.model.state_dict().keys())
        # NOTE(review): missing_keys is computed but never reported or acted on.
        missing_keys = own_keys - ckpt_keys

        if self.rank == 0:
            print(f'==[rank{self.rank}]==load model done.')


    def eval_init(self):
        # Reset running true/false counters for a fresh evaluation pass.
        self.eval_t = 0.0
        self.eval_f = 0.0

    def save_eval_result(self, acc1, acc5):
        # "Save" currently means: print one formatted result line per rank.
        result_line = f'ckpt: {self.model_name}\tacc1: {acc1:.4f}\tacc5: {acc5:.4f}'
        print(f'==[rank{self.rank}]=={result_line}')
/tools/trainer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tools/trainer/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tools/trainer/__pycache__/base_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/__pycache__/base_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /tools/trainer/__pycache__/base_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/__pycache__/base_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /tools/trainer/__pycache__/build_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/__pycache__/build_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /tools/trainer/__pycache__/build_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- 

class BaseTrainer(object):
    """Abstract trainer: stores the shared training state and declares the
    interface (train/save/load) that concrete trainers must override."""

    def __init__(self, dataloader, model, optimizer,
                 lr_scheduler, print_freq,
                 save_path, snapshot, logger):
        # Bind all collaborators in one pass; note the 'snapshot' argument is
        # exposed as 'snapshot_freq'.
        for attr, val in (('dataloader', dataloader), ('model', model),
                          ('optimizer', optimizer), ('lr_scheduler', lr_scheduler),
                          ('print_freq', print_freq), ('snapshot_freq', snapshot),
                          ('save_path', save_path), ('logger', logger)):
            setattr(self, attr, val)
        # Iteration and epoch counters are 1-based.
        self.cur_iter = 1
        self.cur_epoch = 1

    def train(self):
        raise RuntimeError('BaseTrainer cannot train')

    def save(self):
        raise RuntimeError('BaseTrainer cannot save search_space')

    def load(self, ckpt_path):
        raise RuntimeError('BaseTrainer cannot load search_space and optimizer')

    def show_attr(self):
        # Debug helper: dump every instance attribute.
        for name, value in vars(self).items():
            print(name, value)

    def show_task_key(self):
        # NOTE(review): task_key is set by subclasses, not by this __init__.
        print(self.task_key)
def build_trainer(cfg_stg, dataloader, model, optimizer, lr_scheduler, now):
    ''' Build trainer and return.

    Args:
        cfg_stg: strategy config dict (task_type, max_iter, save_path, ...).
        dataloader, model, optimizer, lr_scheduler: training collaborators.
        now: timestamp string appended to the log file name ('' for none).

    Returns:
        An ImagenetTrainer wired with timing meters and a file logger.
    '''
    # choose trainer function
    kwargs = {}
    kwargs['rank'] = dist.get_rank()
    kwargs['world_size'] = dist.get_world_size()
    kwargs['max_iter'] = cfg_stg['max_iter']
    kwargs['quantization'] = cfg_stg.get('quantization', None)
    print_freq = cfg_stg.get('print_freq', 20)
    # Rolling windows over the last `print_freq` iterations for per-phase timing.
    kwargs['data_time'] = AverageMeter(length=print_freq)
    kwargs['forw_time'] = AverageMeter(length=print_freq)
    kwargs['bckw_time'] = AverageMeter(length=print_freq)
    kwargs['step_time'] = AverageMeter(length=print_freq)
    kwargs['batch_time'] = AverageMeter()
    kwargs['mixed_training'] = cfg_stg.get('mixed_training', False)
    if cfg_stg['task_type'] in ['imagenet']:
        trainer = ImagenetTrainer
        kwargs['disp_loss'] = AverageMeter(length=print_freq)
        kwargs['disp_acc_top1'] = AverageMeter(length=print_freq)
        kwargs['disp_acc_top5'] = AverageMeter(length=print_freq)
        kwargs['task_has_accuracy'] = True #search_space.head.task_has_accuracy
    else:
        raise RuntimeError('task_type {} invalid, must be in imagenet'.format(cfg_stg['task_type']))

    if now != '':
        now = '_' + now

    # build logger
    # NOTE(review): only 'imagenet' survives the check above, so the 'verify'
    # branch (and the non-imagenet names below) are dead code here.
    if cfg_stg['task_type'] in ['verify']:
        logger = create_logger('global_logger',
                               '{}/log/log_task{}_train{}.txt'.format(cfg_stg['save_path'], now,
                                                                      model.task_id))
    # TRACKING_TIP
    elif cfg_stg['task_type'] in ['attribute', 'gaze', 'imagenet', 'tracking', 'smoking']:
        logger = create_logger('',
                               '{}/log/'.format(cfg_stg['save_path']) + '/log_train{}.txt'.format(now))
    else:
        raise RuntimeError('task_type musk be in verify/attribute/gaze/imagenet/tracking')

    # build trainer
    final_trainer = trainer(dataloader, model, optimizer, lr_scheduler, print_freq,
                            cfg_stg['save_path'] + '/checkpoint',
                            cfg_stg.get('snapshot_freq', 5000), logger, **kwargs)
    return final_trainer
https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/imagenet/__init__.py -------------------------------------------------------------------------------- /tools/trainer/imagenet/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/imagenet/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tools/trainer/imagenet/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/imagenet/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tools/trainer/imagenet/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/imagenet/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /tools/trainer/imagenet/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luminolx/ScaleNet/22487c892010aa86316fe6d0762723181db58cc8/tools/trainer/imagenet/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /tools/trainer/imagenet/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from os.path import join, exists 3 | 4 | import torch 5 | from torch import distributed as dist 6 | import core.dataset.build_dataloader as BD 7 | from 
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: logits tensor of shape (batch, num_classes).
        target: ground-truth class indices, shape (batch,).
        topk: iterable of k values to evaluate.

    Returns:
        List of 1-element tensors, precision@k in percent, one per k.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape(-1) instead of view(-1): correct[:k] is a slice of the
        # transposed prediction matrix and therefore non-contiguous, and
        # Tensor.view raises on non-contiguous input (PyTorch >= 1.7).
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
self.logger.info(self.model) 67 | 68 | def train(self): 69 | self.model.train() 70 | if self.rank == 0: 71 | self.logger.info('Start training...') 72 | self.logger.info(f'Loading classification data') 73 | loader_iter = iter(BD.DataPrefetcher(self.dataloader)) 74 | input_all = {} 75 | end_time = time.time() 76 | self.optimizer.zero_grad() 77 | while self.cur_iter <= self.max_iter: 78 | self.lr_scheduler.step(self.cur_iter) 79 | tmp_time = time.time() 80 | images, target = next(loader_iter) 81 | flag_epoch_end = False 82 | if images is not None and isinstance(images, dict): 83 | for idx in images.keys(): 84 | if images[idx] is None: 85 | flag_epoch_end = True 86 | break 87 | if images is None or flag_epoch_end: 88 | epoch = int(self.cur_iter / len(self.dataloader)) 89 | if self.rank == 0: 90 | self.logger.info('classification epoch-{} done at iter-{}'.format(epoch, 91 | self.cur_iter)) 92 | self.dataloader.sampler.set_epoch(epoch) 93 | loader_iter = iter(BD.DataPrefetcher(self.dataloader)) 94 | images, target = loader_iter.next() 95 | 96 | input_all['images'] = images 97 | input_all['labels'] = target 98 | 99 | self.data_time.update(time.time() - tmp_time) 100 | tmp_time = time.time() 101 | output = self.model(input_all, c_iter=self.cur_iter) 102 | self.forw_time.update(time.time() - tmp_time) 103 | 104 | loss = output['loss'] 105 | if not self.model.module.asyn and self.model.module.n > 0: 106 | loss /= self.model.module.n 107 | reduced_loss = loss.data.clone() / self.world_size 108 | dist.all_reduce(reduced_loss) 109 | self.disp_loss.update(reduced_loss.item()) 110 | if self.task_has_accuracy: 111 | prec1, prec5 = output['accuracy'] 112 | reduced_prec1 = prec1.clone() / self.world_size 113 | dist.all_reduce(reduced_prec1) 114 | reduced_prec5 = prec5.clone() / self.world_size 115 | dist.all_reduce(reduced_prec5) 116 | self.disp_acc_top1.update(reduced_prec1.item()) 117 | self.disp_acc_top5.update(reduced_prec5.item()) 118 | 119 | tmp_time = time.time() 120 | if 
has_apex: 121 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 122 | scaled_loss.backward() 123 | else: 124 | loss.backward() 125 | self.bckw_time.update(time.time() - tmp_time) 126 | # torch.cuda.empty_cache() 127 | if not self.model.module.asyn and self.model.module.n > 0: 128 | for _ in range(1, self.model.module.n): 129 | output = self.model(input_all, subnet=output['subnet'], c_iter=self.cur_iter) 130 | loss = output['loss'] / self.model.module.n 131 | reduced_loss = loss.data.clone() / self.world_size 132 | dist.all_reduce(reduced_loss) 133 | self.disp_loss.update(reduced_loss.item()) 134 | if has_apex: 135 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 136 | scaled_loss.backward() 137 | else: 138 | loss.backward() 139 | 140 | tmp_time = time.time() 141 | self.optimizer.step() 142 | self.step_time.update(time.time() - tmp_time) 143 | self.optimizer.zero_grad() 144 | self.batch_time.update(time.time() - end_time) 145 | end_time = time.time() 146 | # vis loss 147 | if self.cur_iter % self.print_freq == 0 and self.rank == 0: 148 | self.logger.info('Iter: [{0}/{1}] ' 149 | 'BatchTime {batch_time.avg:.4f} | ' 150 | 'DataTime {data_time.avg:.4f} | ' 151 | 'ForwardTime {forw_time.avg:.4f} | ' 152 | 'BackwardTime {bckw_time.avg:.4f} | ' 153 | 'StepTime {step_time.avg:.4f} | ' 154 | 'Total {batch_time.all:.2f} hrs | ' 155 | 'Loss {loss.avg:.4f} | ' 156 | 'Prec@1 {top1.avg:.4f} | ' 157 | 'Prec@5 {top5.avg:.4f} | ' 158 | 'LR {lr:.6f} | ETA {eta:.2f} hrs'.format( 159 | self.cur_iter, self.max_iter, 160 | batch_time=self.batch_time, 161 | data_time=self.data_time, 162 | forw_time=self.forw_time, 163 | bckw_time=self.bckw_time, 164 | step_time=self.step_time, 165 | loss=self.disp_loss, 166 | top1=self.disp_acc_top1, 167 | top5=self.disp_acc_top5, 168 | lr=self.lr_scheduler.get_lr()[0], 169 | eta=self.batch_time.avg * (self.max_iter - self.cur_iter) / 3600)) 170 | 171 | # save search_space 172 | if self.cur_iter % self.snapshot_freq == 0: 173 | if 
self.rank == 0: 174 | self.save() 175 | self.cur_iter += 1 176 | 177 | # finish training 178 | self.logger.info('Finish training {} iterations.'.format(self.cur_iter)) 179 | 180 | def save(self): 181 | ''' save search_space ''' 182 | path = join(self.save_path, 'iter_{}_ckpt.pth.tar'.format(self.cur_iter)) 183 | latest_path = join(self.save_path, 'latest.pth.tar') 184 | if has_apex: 185 | torch.save({'step': self.cur_iter, 'state_dict': self.model.state_dict(), 186 | 'optimizer': self.optimizer.state_dict(), 'amp': amp.state_dict()}, latest_path) 187 | else: 188 | torch.save({'step': self.cur_iter, 'state_dict': self.model.state_dict(), 189 | 'optimizer': self.optimizer.state_dict()}, latest_path) 190 | self.logger.info('[rank{}]Saved search_space to {}.'.format(self.rank, latest_path)) 191 | 192 | if self.cur_iter % 10000 == 0: 193 | if has_apex: 194 | torch.save({'step': self.cur_iter, 'state_dict': self.model.state_dict(), 195 | 'optimizer': self.optimizer.state_dict(), 'amp': amp.state_dict()}, path) 196 | else: 197 | torch.save({'step': self.cur_iter, 'state_dict': self.model.state_dict(), 198 | 'optimizer': self.optimizer.state_dict()}, path) 199 | self.logger.info('[rank{}]Saved search_space to {}.'.format(self.rank, path)) 200 | 201 | def load(self, ckpt_path): 202 | ''' load search_space and optimizer ''' 203 | 204 | def map_func(storage, location): 205 | return storage.cuda() 206 | 207 | assert exists(ckpt_path), f'{ckpt_path} not exist.' 
    def load(self, ckpt_path):
        ''' load search_space and optimizer

        Resumes model weights and the iteration counter from ckpt_path; also
        restores apex/amp state when present.
        NOTE(review): despite the comment below, the optimizer state_dict is
        never actually restored here — confirm whether that is intentional.
        '''

        def map_func(storage, location):
            # Deserialize all tensors directly onto the current GPU.
            return storage.cuda()

        assert exists(ckpt_path), f'{ckpt_path} not exist.'
        ckpt = torch.load(ckpt_path, map_location=map_func)
        # strict=False tolerates extra/missing keys at load time, but the loop
        # below re-checks and hard-fails on ANY missing key — effectively
        # strict loading with a custom error message.
        self.model.load_state_dict(ckpt['state_dict'], strict=False)
        ckpt_keys = set(ckpt['state_dict'].keys())
        own_keys = set(self.model.state_dict().keys())
        missing_keys = own_keys - ckpt_keys
        for k in missing_keys:
            if self.rank == 0:
                self.logger.info(f'**missing key while loading search_space**: {k}')
            raise RuntimeError(f'**missing key while loading search_space**: {k}')

        # load optimizer
        self.cur_iter = ckpt['step'] + 1
        # Re-derive the epoch so the sampler reshuffles consistently on resume.
        epoch = int(self.cur_iter / len(self.dataloader))
        self.dataloader.sampler.set_epoch(epoch)
        if self.rank == 0:
            self.logger.info('load [resume] search_space done, '
                             f'current iter is {self.cur_iter}')

        # load amp
        if has_apex and 'amp' in ckpt:
            amp.load_state_dict(ckpt['amp'])