├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── eval_alphanet_models.yml
│   ├── parallel_supernet_evo_search.yml
│   └── train_alphanet_models.yml
├── loss_ops.py
├── parallel_supernet_evo_search.py
├── test_alphanet.py
└── train_alphanet.py


/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 | 
3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4 | Please read the [full text](https://code.fb.com/codeofconduct/)
5 | so that you can understand what actions will and will not be tolerated.
6 | 
7 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to AlphaNet
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 | 
5 | ## Our Development Process
6 | Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis.
7 | 
8 | ## Pull Requests
9 | We actively welcome your pull requests.
10 | 
11 | 1. Fork the repo and create your branch from `master`.
12 | 2. If you've added code that should be tested, add tests.
13 | 3. If you've changed APIs, update the documentation.
14 | 4. Ensure the test suite passes.
15 | 5. Make sure your code lints.
16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
17 | 
18 | ## Contributor License Agreement ("CLA")
19 | In order to accept your pull request, we need you to submit a CLA. You only need
20 | to do this once to work on any of Facebook's open source projects.
21 | 
22 | Complete your CLA here: <https://code.facebook.com/cla>
23 | 
24 | ## Issues
25 | We use GitHub issues to track public bugs. Please ensure your description is
26 | clear and has sufficient instructions to be able to reproduce the issue.
27 | 
28 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
29 | disclosure of security bugs. In those cases, please go through the process
30 | outlined on that page and do not file a public issue.
31 | 
32 | ## License
33 | By contributing to AlphaNet, you agree that your contributions will be licensed
34 | under the LICENSE file in the root directory of this source tree.
35 | 
36 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 | 
3 | =======================================================================
4 | 
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. 
For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. 
Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 
218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. 
No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlphaNet: Improved Training of Supernet with Alpha-Divergence 2 | This repository contains our PyTorch training code, evaluation code and pretrained models for AlphaNet. 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/alphanet-improved-training-of-supernet-with/neural-architecture-search-on-imagenet)](https://paperswithcode.com/sota/neural-architecture-search-on-imagenet?p=alphanet-improved-training-of-supernet-with) 5 | 6 | Our implementation is largely based on [AttentiveNAS](https://arxiv.org/pdf/2011.09011.pdf). 7 | To reproduce our results, please first download the [AttentiveNAS repo](https://github.com/facebookresearch/AttentiveNAS), and use our *train\_alphanet.py* for training and *test\_alphanet.py* for testing. 8 | 9 | For more details, please see [AlphaNet: Improved Training of Supernet with Alpha-Divergence](https://arxiv.org/pdf/2102.07954.pdf) by Dilin Wang, Chengyue Gong, Meng Li, Qiang Liu, Vikas Chandra. 
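At the core of AlphaNet is an adaptive alpha-divergence distillation loss (`AdaptiveLossSoft` in *loss_ops.py*), which replaces the usual KL-based inplace distillation when training the supernet. Below is a minimal, hypothetical sketch of how it can be applied to a pair of student/teacher logits; the random tensors are placeholders for illustration only, and the hyper-parameters follow the defaults in [train_alphanet_models.yml](configs/train_alphanet_models.yml):

```python
import torch
from loss_ops import AdaptiveLossSoft

# alpha_min / alpha_max / iw_clip as in configs/train_alphanet_models.yml
soft_criterion = AdaptiveLossSoft(alpha_min=-1.0, alpha_max=1.0, iw_clip=5.0)

student_logits = torch.randn(8, 1000, requires_grad=True)  # e.g., a sampled sub-network
teacher_logits = torch.randn(8, 1000)                      # e.g., the largest sub-network

loss = soft_criterion(student_logits, teacher_logits)  # teacher side is detached internally
loss.backward()
```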
10 | 
11 | If you find this repo useful in your research, please consider citing our work and [AttentiveNAS](https://arxiv.org/pdf/2011.09011.pdf):
12 | 
13 | ```BibTex
14 | @article{wang2021alphanet,
15 |   title={AlphaNet: Improved Training of Supernet with Alpha-Divergence},
16 |   author={Wang, Dilin and Gong, Chengyue and Li, Meng and Liu, Qiang and Chandra, Vikas},
17 |   journal={arXiv preprint arXiv:2102.07954},
18 |   year={2021}
19 | }
20 | 
21 | @article{wang2020attentivenas,
22 |   title={AttentiveNAS: Improving Neural Architecture Search via Attentive Sampling},
23 |   author={Wang, Dilin and Li, Meng and Gong, Chengyue and Chandra, Vikas},
24 |   journal={arXiv preprint arXiv:2011.09011},
25 |   year={2020}
26 | }
27 | ```
28 | 
29 | ## Evaluation
30 | To reproduce our results:
31 | - Please first download our [pretrained AlphaNet models](https://drive.google.com/file/d/1CyZoPyiCoGJ0qv8bqi7s7TQRUum_8FeG/view?usp=sharing) from Google Drive and put the pretrained models under your local folder *./alphanet_data*
32 | 
33 | - To evaluate our pre-trained AlphaNet models, from AlphaNet-A0 to A6, on ImageNet with a single GPU, please run:
34 | 
35 | ```bash
36 | python test_alphanet.py --config-file ./configs/eval_alphanet_models.yml --model a[0-6]
37 | ```
38 | 
39 | Expected results (`--model a5` evaluates AlphaNet-A5 (small); use `--model a5_1` for AlphaNet-A5 (base)):
40 | 
41 | | Name | MFLOPs | Top-1 (%) |
42 | | :------------ |:---------------:| -----:|
43 | | AlphaNet-A0 | 203 | 77.87 |
44 | | AlphaNet-A1 | 279 | 78.94 |
45 | | AlphaNet-A2 | 317 | 79.20 |
46 | | AlphaNet-A3 | 357 | 79.41 |
47 | | AlphaNet-A4 | 444 | 80.01 |
48 | | AlphaNet-A5 (small) | 491 | 80.29 |
49 | | AlphaNet-A5 (base) | 596 | 80.62 |
50 | | AlphaNet-A6 | 709 | 80.78 |
51 | 
52 | - Additionally, [here](https://drive.google.com/file/d/1NgZhJy8MJnuxjXkJ0gfnBGyrUVYwbAmx/view?usp=sharing) is our pretrained supernet trained with KL-based inplace-KD and [here](https://drive.google.com/file/d/1rj1opDnlBD2_8ZV--LUSn8HXWfhiMdu8/view?usp=sharing) is our pretrained supernet trained without inplace-KD.
53 | 
54 | ## Training
55 | To train our AlphaNet models from scratch, please run:
56 | ```bash
57 | python train_alphanet.py --config-file configs/train_alphanet_models.yml --machine-rank ${machine_rank} --num-machines ${num_machines} --dist-url ${dist_url}
58 | ```
59 | We adopt SGD training on 64 GPUs. The mini-batch size is 32 per GPU; all training hyper-parameters are specified in [train_alphanet_models.yml](configs/train_alphanet_models.yml).
60 | 
61 | ## Evolutionary search
62 | If you want to search for a set of models of your own interest, _parallel_supernet_evo_search.py_ provides an example that searches for the Pareto-optimal models with the best FLOPs vs. accuracy trade-offs; to run this example:
63 | ```bash
64 | python parallel_supernet_evo_search.py --config-file configs/parallel_supernet_evo_search.yml
65 | ```
66 | 
67 | ## License
68 | AlphaNet is licensed under CC BY-NC 4.0.
69 | 
70 | ## Contributing
71 | We actively welcome your pull requests! Please see [CONTRIBUTING](CONTRIBUTING.md) and [CODE_OF_CONDUCT](CODE_OF_CONDUCT.md) for more info.
72 | 
73 | 
74 | 
--------------------------------------------------------------------------------
/configs/eval_alphanet_models.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | # Modified from AttentiveNAS (https://github.com/facebookresearch/AttentiveNAS) 3 | 4 | arch: 'attentive_nas_static_model' 5 | 6 | pareto_models: 7 | supernet_checkpoint_path: "./alphanet_data/alphanet_pretrained.pth.tar" 8 | 9 | a0: 10 | model: "alphanet_a0" 11 | resolution: 192 12 | width: [16, 16, 24, 32, 64, 112, 192, 216, 1792] 13 | kernel_size: [3, 3, 3, 3, 3, 3, 3] 14 | expand_ratio: [1, 4, 4, 4, 4, 6, 6] 15 | depth: [1, 3, 3, 3, 3, 3, 1] 16 | 17 | a1: 18 | model: "alphanet_a1" 19 | resolution: 224 20 | width: [16, 16, 24, 32, 64, 112, 192, 216, 1984] 21 | kernel_size: [3, 3, 3, 5, 3, 5, 3] 22 | expand_ratio: [1, 4, 4, 4, 4, 6, 6] 23 | depth: [1, 3, 3, 3, 3, 3, 1] 24 | 25 | a2: 26 | model: "alphanet_a2" 27 | resolution: 224 28 | width: [16, 16, 24, 32, 64, 112, 200, 224, 1984] 29 | kernel_size: [3, 3, 3, 3, 3, 5, 3] 30 | expand_ratio: [1, 4, 5, 4, 4, 6, 6] 31 | depth: [1, 3, 3, 3, 3, 4, 1] 32 | 33 | a3: 34 | model: "alphanet_a3" 35 | resolution: 224 36 | width: [16, 16, 24, 32, 64, 112, 208, 224, 1984] 37 | kernel_size: [3, 3, 3, 5, 3, 3, 3] 38 | expand_ratio: [1, 4, 4, 4, 4, 6, 6] 39 | depth: [2, 3, 3, 4, 3, 5, 1] 40 | 41 | a4: 42 | model: "alphanet_a4" 43 | resolution: 256 44 | width: [16, 16, 24, 32, 64, 112, 192, 216, 1984] 45 | kernel_size: [3, 3, 3, 5, 3, 5, 3] 46 | expand_ratio: [1, 4, 4, 5, 4, 6, 6] 47 | depth: [1, 3, 3, 4, 3, 5, 1] 48 | 49 | # this is different from the a5 used in AttentiveNAS 50 | # see https://github.com/facebookresearch/AlphaNet/issues/1 51 | a5: 52 | model: "alphanet_a5_small" 53 | resolution: 256 54 | width: [16, 16, 24, 32, 64, 112, 208, 216, 1984] 55 | kernel_size: [3, 3, 3, 3, 3, 5, 3] 56 | expand_ratio: [1, 4, 4, 5, 4, 6, 6] 57 | depth: [1, 3, 3, 4, 4, 5, 1] 58 | 59 | a5_1: 60 | model: "alphanet_a5_base" 61 | resolution: 288 62 | width: [16, 16, 24, 32, 64, 112, 200, 216, 1984] 63 | kernel_size: [3, 3, 5, 5, 3, 3, 3] 64 | expand_ratio: [1, 4, 4, 4, 5, 6, 6] 65 | depth: [2, 3, 4, 3, 3, 5, 1] 66 | 67 | a6: 68 | model: "alphanet_a6" 69 | resolution: 288 70 | width: [16, 16, 24, 32, 64, 112, 216, 224, 1984] 71 | kernel_size: [3, 3, 3, 3, 3, 5, 3] 72 | expand_ratio: [1, 4, 6, 5, 4, 6, 6] 73 | depth: [1, 3, 3, 4, 4, 6, 1] 74 | 75 | 76 | batch_size: 256 77 | post_bn_calibration_batch_num: 64 78 | 79 | augment: "auto_augment_tf" 80 | 81 | bn_momentum: 0.1 82 | bn_eps: 1e-5 83 | 84 | distributed: False 85 | distributed_val: False 86 | eval_only: True 87 | 88 | ### imagenet dataset ### 89 | dataset: 'imagenet' 90 | dataset_dir: "/data/local/packages/ai-group.imagenet-full-size/prod/imagenet_full_size/" 91 | n_classes: 1000 92 | drop_last: True 93 | data_loader_workers_per_gpu: 4 94 | 95 | print_freq: 10 96 | seed: 5 97 | 98 | #attentive nas search space 99 | # c: channels, d: layers, k: kernel size, t: expand ratio, s: stride, act: activation, se: se layer 100 | supernet_config: 101 | use_v3_head: True 102 | resolutions: [192, 224, 256, 288] 103 | first_conv: 104 | c: [16, 24] 105 | act_func: 'swish' 106 | s: 2 107 | mb1: 108 | c: [16, 24] 109 | d: [1, 2] 110 | k: [3, 5] 111 | t: [1] 112 | s: 1 113 | act_func: 'swish' 114 | se: False 115 | mb2: 116 | c: [24, 32] 117 | d: [3, 4, 5] 118 | k: [3, 5] 119 | t: [4, 5, 6] 120 | s: 2 121 | act_func: 'swish' 122 | se: False 123 | mb3: 124 | c: [32, 40] 125 | d: [3, 4, 5, 6] 126 | k: [3, 5] 127 | t: [4, 5, 6] 128 | s: 2 129 | act_func: 'swish' 130 | se: True 131 | mb4: 132 | c: [64, 72] 133 | d: [3, 4, 5, 6] 134 | k: [3, 5] 135 | t: [4, 5, 6] 136 | s: 2 137 | act_func: 'swish' 138 | se: False 
139 |     mb5:
140 |         c: [112, 120, 128]
141 |         d: [3, 4, 5, 6, 7, 8]
142 |         k: [3, 5]
143 |         t: [4, 5, 6]
144 |         s: 1
145 |         act_func: 'swish'
146 |         se: True
147 |     mb6:
148 |         c: [192, 200, 208, 216]
149 |         d: [3, 4, 5, 6, 7, 8]
150 |         k: [3, 5]
151 |         t: [6]
152 |         s: 2
153 |         act_func: 'swish'
154 |         se: True
155 |     mb7:
156 |         c: [216, 224]
157 |         d: [1, 2]
158 |         k: [3, 5]
159 |         t: [6]
160 |         s: 1
161 |         act_func: 'swish'
162 |         se: True
163 |     last_conv:
164 |         c: [1792, 1984]
165 |         act_func: 'swish'
166 | 
167 | 
--------------------------------------------------------------------------------
/configs/parallel_supernet_evo_search.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
3 | arch: 'attentive_nas_dynamic_model'
4 | exp_name: "parallel_supernet_evo_search"
5 | 
6 | # your pretrained supernet path, e.g.,
7 | resume: "./alphanet_data/alphanet_pretrained.pth.tar"
8 | 
9 | # settings for BN calibration
10 | batch_size_per_gpu: 256
11 | post_bn_calibration_batch_num: 32
12 | augment: "auto_augment_tf"
13 | 
14 | evo_search:
15 |     #default 512
16 |     parent_popu_size: 4
17 |     #default 128
18 |     mutate_size: 4
19 |     #default 128
20 |     crossover_size: 4
21 | 
22 |     mutate_prob: 0.2
23 |     evo_iter: 20
24 |     targeted_min_flops: 200
25 |     targeted_max_flops: 1200
26 |     step: 10
27 | 
28 | bn_momentum: 0.1
29 | bn_eps: 1e-5
30 | 
31 | # just in case you have more GPUs
32 | n_gpu_per_node: 8
33 | data_loader_workers_per_gpu: 4
34 | num_nodes: 8
35 | n_cpu_per_node: 32
36 | gpu_type: 'GPU_V100_HOST'
37 | memory_per_node: '128g'
38 | 
39 | ### distributed settings ###
40 | distributed: True
41 | distributed_val: False
42 | multiprocessing_distributed: True
43 | dist_backend: 'nccl'
44 | eval_only: True
45 | 
46 | ### imagenet dataset ###
47 | dataset: 'imagenet'
48 | dataset_dir: "/data/local/packages/ai-group.imagenet-full-size/prod/imagenet_full_size/"
49 | n_classes: 1000
50 | drop_last: True
51 | 
52 | print_freq: 50
53 | seed: 0
54 | 
55 | #attentive nas search space
56 | # c: channels, d: layers, k: kernel size, t: expand ratio, s: stride, act: activation, se: se layer
57 | supernet_config:
58 |     use_v3_head: True
59 |     resolutions: [192, 224, 256, 288]
60 |     first_conv:
61 |         c: [16, 24]
62 |         act_func: 'swish'
63 |         s: 2
64 |     mb1:
65 |         c: [16, 24]
66 |         d: [1, 2]
67 |         k: [3, 5]
68 |         t: [1]
69 |         s: 1
70 |         act_func: 'swish'
71 |         se: False
72 |     mb2:
73 |         c: [24, 32]
74 |         d: [3, 4, 5]
75 |         k: [3, 5]
76 |         t: [4, 5, 6]
77 |         s: 2
78 |         act_func: 'swish'
79 |         se: False
80 |     mb3:
81 |         c: [32, 40]
82 |         d: [3, 4, 5, 6]
83 |         k: [3, 5]
84 |         t: [4, 5, 6]
85 |         s: 2
86 |         act_func: 'swish'
87 |         se: True
88 |     mb4:
89 |         c: [64, 72]
90 |         d: [3, 4, 5, 6]
91 |         k: [3, 5]
92 |         t: [4, 5, 6]
93 |         s: 2
94 |         act_func: 'swish'
95 |         se: False
96 |     mb5:
97 |         c: [112, 120, 128]
98 |         d: [3, 4, 5, 6, 7, 8]
99 |         k: [3, 5]
100 |         t: [4, 5, 6]
101 |         s: 1
102 |         act_func: 'swish'
103 |         se: True
104 |     mb6:
105 |         c: [192, 200, 208, 216]
106 |         d: [3, 4, 5, 6, 7, 8]
107 |         k: [3, 5]
108 |         t: [6]
109 |         s: 2
110 |         act_func: 'swish'
111 |         se: True
112 |     mb7:
113 |         c: [216, 224]
114 |         d: [1, 2]
115 |         k: [3, 5]
116 |         t: [6]
117 |         s: 1
118 |         act_func: 'swish'
119 |         se: True
120 |     last_conv:
121 |         c: [1792, 1984]
122 |         act_func: 'swish'
123 | 
124 | 
--------------------------------------------------------------------------------
/configs/train_alphanet_models.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c)
Facebook, Inc. and its affiliates. All Rights Reserved 2 | # Modified from AttentiveNAS (https://github.com/facebookresearch/AttentiveNAS) 3 | 4 | arch: 'attentive_nas_dynamic_model' 5 | 6 | exp_name: "alphanet" 7 | 8 | batch_size_per_gpu: 32 9 | sandwich_rule: True 10 | 11 | alpha_min: -1.0 12 | alpha_max: 1.0 13 | iw_clip: 5.0 14 | 15 | grad_clip_value: 1.0 16 | 17 | augment: "auto_augment_tf" 18 | 19 | n_gpu_per_node: 8 20 | num_nodes: 8 21 | n_cpu_per_node: 32 22 | memory_per_node: '128g' 23 | 24 | warmup_epochs: 5 25 | epochs: 360 26 | start_epoch: 0 27 | 28 | label_smoothing: 0.1 29 | inplace_distill: True 30 | 31 | #sync-batchnormalization, suggested to use in bignas 32 | sync_bn: False 33 | 34 | bn_momentum: 0 35 | bn_eps: 1e-5 36 | 37 | post_bn_calibration_batch_num: 64 38 | 39 | num_arch_training: 4 40 | 41 | models_save_dir: "./saved_models" 42 | 43 | #### cloud training resources #### 44 | data_loader_workers_per_gpu: 4 45 | 46 | ########### regularization ################ 47 | # supernet training regularization (the largest network) 48 | dropout: 0.2 49 | drop_connect: 0.2 50 | drop_connect_only_last_two_stages: True 51 | 52 | weight_decay_weight: 0.00001 53 | weight_decay_bn_bias: 0. 54 | 55 | ## =================== optimizer and scheduler======================== # 56 | optimizer: 57 | method: sgd 58 | momentum: 0.9 59 | nesterov: True 60 | 61 | lr_scheduler: 62 | method: "warmup_cosine_lr" 63 | base_lr: 0.1 64 | clamp_lr_percent: 0.0 65 | 66 | 67 | ### distributed training settings ### 68 | multiprocessing_distributed: True 69 | dist_backend: 'nccl' 70 | distributed: True 71 | 72 | 73 | ### imagenet dataset ### 74 | dataset: 'imagenet' 75 | dataset_dir: "/data/local/packages/ai-group.imagenet-full-size/prod/imagenet_full_size/" 76 | n_classes: 1000 77 | drop_last: True 78 | 79 | print_freq: 10 80 | resume: "" 81 | 82 | seed: 0 83 | 84 | #attentive nas search space 85 | # c: channels, d: layers, k: kernel size, t: expand ratio, s: stride, act: activation, se: se layer 86 | supernet_config: 87 | use_v3_head: True 88 | resolutions: [192, 224, 256, 288] 89 | first_conv: 90 | c: [16, 24] 91 | act_func: 'swish' 92 | s: 2 93 | mb1: 94 | c: [16, 24] 95 | d: [1, 2] 96 | k: [3, 5] 97 | t: [1] 98 | s: 1 99 | act_func: 'swish' 100 | se: False 101 | mb2: 102 | c: [24, 32] 103 | d: [3, 4, 5] 104 | k: [3, 5] 105 | t: [4, 5, 6] 106 | s: 2 107 | act_func: 'swish' 108 | se: False 109 | mb3: 110 | c: [32, 40] 111 | d: [3, 4, 5, 6] 112 | k: [3, 5] 113 | t: [4, 5, 6] 114 | s: 2 115 | act_func: 'swish' 116 | se: True 117 | mb4: 118 | c: [64, 72] 119 | d: [3, 4, 5, 6] 120 | k: [3, 5] 121 | t: [4, 5, 6] 122 | s: 2 123 | act_func: 'swish' 124 | se: False 125 | mb5: 126 | c: [112, 120, 128] 127 | d: [3, 4, 5, 6, 7, 8] 128 | k: [3, 5] 129 | t: [4, 5, 6] 130 | s: 1 131 | act_func: 'swish' 132 | se: True 133 | mb6: 134 | c: [192, 200, 208, 216] 135 | d: [3, 4, 5, 6, 7, 8] 136 | k: [3, 5] 137 | t: [6] 138 | s: 2 139 | act_func: 'swish' 140 | se: True 141 | mb7: 142 | c: [216, 224] 143 | d: [1, 2] 144 | k: [3, 5] 145 | t: [6] 146 | s: 1 147 | act_func: 'swish' 148 | se: True 149 | last_conv: 150 | c: [1792, 1984] 151 | act_func: 'swish' 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /loss_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved
2 | # Implementation adapted from Slimmable - https://github.com/JiahuiYu/slimmable_networks
3 | 
4 | import torch
5 | 
6 | class CrossEntropyLossSoft(torch.nn.modules.loss._Loss):
7 |     """ inplace distillation for image classification """
8 |     def forward(self, output, target):
9 |         output_log_prob = torch.nn.functional.log_softmax(output, dim=1)
10 |         target = target.unsqueeze(1)
11 |         output_log_prob = output_log_prob.unsqueeze(2)
12 |         cross_entropy_loss = -torch.bmm(target, output_log_prob)
13 |         return cross_entropy_loss.mean()
14 | 
15 | 
16 | 
17 | class KLLossSoft(torch.nn.modules.loss._Loss):
18 |     """ inplace distillation for image classification
19 |         output: output logits of the student network
20 |         target: output logits of the teacher network
21 |         T: temperature
22 |         KL(p||q) = E_p[\log p] - E_p[\log q]
23 |     """
24 |     def forward(self, output, soft_logits, target=None, temperature=1., alpha=0.9):
25 |         output, soft_logits = output / temperature, soft_logits / temperature
26 |         soft_target_prob = torch.nn.functional.softmax(soft_logits, dim=1)
27 |         output_log_prob = torch.nn.functional.log_softmax(output, dim=1)
28 |         kd_loss = -torch.sum(soft_target_prob * output_log_prob, dim=1)
29 |         if target is not None:
30 |             n_class = output.size(1)
31 |             target = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1)
32 |             target = target.unsqueeze(1)
33 |             output_log_prob = output_log_prob.unsqueeze(2)
34 |             ce_loss = -torch.bmm(target, output_log_prob).squeeze()
35 |             loss = alpha * temperature * temperature * kd_loss + (1.0 - alpha) * ce_loss
36 |         else:
37 |             loss = kd_loss
38 | 
39 |         if self.reduction == 'mean':
40 |             return loss.mean()
41 |         elif self.reduction == 'sum':
42 |             return loss.sum()
43 |         return loss
44 | 
45 | 
46 | class CrossEntropyLossSmooth(torch.nn.modules.loss._Loss):
47 |     def __init__(self, label_smoothing=0.1):
48 |         super(CrossEntropyLossSmooth, self).__init__()
49 |         self.eps = label_smoothing
50 | 
51 |     """ label smooth """
52 |     def forward(self, output, target):
53 |         n_class = output.size(1)
54 |         one_hot = torch.zeros_like(output).scatter(1, target.view(-1, 1), 1)
55 |         target = one_hot * (1 - self.eps) + self.eps / n_class
56 |         output_log_prob = torch.nn.functional.log_softmax(output, dim=1)
57 |         target = target.unsqueeze(1)
58 |         output_log_prob = output_log_prob.unsqueeze(2)
59 |         loss = -torch.bmm(target, output_log_prob)
60 | 
61 |         if self.reduction == 'mean':
62 |             return loss.mean()
63 |         elif self.reduction == 'sum':
64 |             return loss.sum()
65 |         return loss
66 | 
67 | 
68 | def f_divergence(q_logits, p_logits, alpha, iw_clip=1e3):
69 |     assert isinstance(alpha, float)
70 |     q_prob = torch.nn.functional.softmax(q_logits, dim=1).detach()
71 |     p_prob = torch.nn.functional.softmax(p_logits, dim=1).detach()
72 |     q_log_prob = torch.nn.functional.log_softmax(q_logits, dim=1)  # gradient is only backpropagated here
73 | 
74 |     importance_ratio = p_prob / q_prob
75 |     if abs(alpha) < 1e-3:
76 |         importance_ratio = importance_ratio.clamp(0, iw_clip)
77 |         f = -importance_ratio.log()
78 |         f_base = 0
79 |         rho_f = importance_ratio.log() - 1.0
80 |     elif abs(alpha - 1.0) < 1e-3:
81 |         f = importance_ratio * importance_ratio.log()
82 |         f_base = 0
83 |         rho_f = importance_ratio
84 |     else:
85 |         iw_alpha = torch.pow(importance_ratio, alpha)
86 |         iw_alpha = iw_alpha.clamp(0, iw_clip)
87 |         f = iw_alpha / alpha / (alpha - 1.0)
88 |         f_base = 1.0 / alpha / (alpha - 1.0)
89 |         rho_f = iw_alpha / alpha + f_base
90 | 
91 |     loss = torch.sum(q_prob * (f - f_base), dim=1)
92 |     grad_loss = -torch.sum(q_prob * rho_f * 
q_log_prob, dim=1)
93 |     return loss, grad_loss
94 | 
95 | 
96 | """
97 | It's often necessary to clip the maximum
98 | gradient value (e.g., 1.0) when using this adaptive KD loss
99 | """
100 | class AdaptiveLossSoft(torch.nn.modules.loss._Loss):
101 |     def __init__(self, alpha_min=-1.0, alpha_max=1.0, iw_clip=5.0):
102 |         super(AdaptiveLossSoft, self).__init__()
103 |         self.alpha_min = alpha_min
104 |         self.alpha_max = alpha_max
105 |         self.iw_clip = iw_clip
106 | 
107 |     def forward(self, output, target, alpha_min=None, alpha_max=None):
108 |         alpha_min = alpha_min or self.alpha_min
109 |         alpha_max = alpha_max or self.alpha_max
110 | 
111 |         loss_left, grad_loss_left = f_divergence(output, target, alpha_min, iw_clip=self.iw_clip)
112 |         loss_right, grad_loss_right = f_divergence(output, target, alpha_max, iw_clip=self.iw_clip)
113 | 
114 |         ind = torch.gt(loss_left, loss_right).float()
115 |         loss = ind * grad_loss_left + (1.0 - ind) * grad_loss_right
116 | 
117 |         if self.reduction == 'mean':
118 |             return loss.mean()
119 |         elif self.reduction == 'sum':
120 |             return loss.sum()
121 |         return loss
122 | 
123 | 
124 | 
--------------------------------------------------------------------------------
/parallel_supernet_evo_search.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
3 | import argparse
4 | import random
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.distributed as dist
9 | import torch.multiprocessing as mp
10 | 
11 | import models
12 | from utils.config import setup
13 | import utils.comm as comm
14 | import utils.saver as saver
15 | 
16 | from data.data_loader import build_data_loader
17 | from evaluate import attentive_nas_eval as attentive_nas_eval
18 | import utils.logging as logging
19 | 
20 | 
21 | """
22 | using multiple nodes to run evolutionary search:
23 | 1) each GPU will evaluate its own sub-networks
24 | 2) all evaluation results will be aggregated on GPU 0
25 | """
26 | parser = argparse.ArgumentParser(description='AlphaNet Evolutionary Search')
27 | parser.add_argument('--config-file', default='./configs/parallel_supernet_evo_search.yml')
28 | parser.add_argument('--machine-rank', default=0, type=int,
29 |                     help='machine rank, distributed setting')
30 | parser.add_argument('--num-machines', default=1, type=int,
31 |                     help='number of nodes, distributed setting')
32 | parser.add_argument('--dist-url', default="tcp://127.0.0.1:10001", type=str,
33 |                     help='init method, distributed setting')
34 | parser.add_argument('--seed', default=1, type=int,
35 |                     help='default random seed')
36 | run_args = parser.parse_args()
37 | 
38 | 
39 | logger = logging.get_logger(__name__)
40 | 
41 | 
42 | def eval_worker(gpu, ngpus_per_node, args):
43 |     args.gpu = gpu  # local rank, local machine cuda id
44 |     args.local_rank = args.gpu
45 |     args.batch_size = args.batch_size_per_gpu
46 | 
47 |     global_rank = args.gpu + args.machine_rank * ngpus_per_node
48 |     dist.init_process_group(
49 |         backend=args.dist_backend,
50 |         init_method=args.dist_url,
51 |         world_size=args.world_size,
52 |         rank=global_rank
53 |     )
54 | 
55 |     # Setup logging format.
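    # Note: unlike train_alphanet.py, which logs under args.models_save_dir, each search worker here (re)creates stdout.log in the current working directory.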
56 | logging.setup_logging("stdout.log", 'w') 57 | 58 | # synchronize is needed here to prevent a possible timeout after calling 59 | # init_process_group 60 | # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 61 | comm.synchronize() 62 | 63 | args.rank = comm.get_rank() # global rank 64 | torch.cuda.set_device(args.gpu) 65 | 66 | random.seed(args.seed) 67 | torch.manual_seed(args.seed) 68 | torch.cuda.manual_seed(args.seed) 69 | 70 | # build the supernet 71 | logger.info("=> creating model '{}'".format(args.arch)) 72 | model = models.model_factory.create_model(args) 73 | model.cuda(args.gpu) 74 | model = comm.get_parallel_model(model, args.gpu) #local rank 75 | 76 | # define loss function (criterion) 77 | criterion = nn.CrossEntropyLoss().cuda() 78 | 79 | ## load dataset, train_sampler: distributed 80 | train_loader, val_loader, train_sampler = build_data_loader(args) 81 | 82 | assert args.resume 83 | #reloading model 84 | model.module.load_weights_from_pretrained_models(args.resume) 85 | 86 | if train_sampler: 87 | train_sampler.set_epoch(0) 88 | 89 | targeted_min_flops = args.evo_search.targeted_min_flops 90 | targeted_max_flops = args.evo_search.targeted_max_flops 91 | 92 | # run evolutionary search 93 | parent_popu = [] 94 | for idx in range(args.evo_search.parent_popu_size): 95 | if idx == 0: 96 | cfg = model.module.sample_min_subnet() 97 | else: 98 | cfg = model.module.sample_active_subnet_within_range( 99 | targeted_min_flops, targeted_max_flops 100 | ) 101 | cfg['net_id'] = f'net_{idx % args.world_size}_evo_0_{idx}' 102 | parent_popu.append(cfg) 103 | 104 | pareto_global = {} 105 | for evo in range(args.evo_search.evo_iter): 106 | # partition the set of candidate sub-networks 107 | # and send them to each GPU for parallel evaluation 108 | 109 | # sub-networks to be evaluated on GPU {args.rank} 110 | my_subnets_to_be_evaluated = {} 111 | n_evaluated = len(parent_popu) // args.world_size * args.world_size 112 | for cfg in parent_popu[:n_evaluated]: 113 | if cfg['net_id'].startswith(f'net_{args.rank}_'): 114 | my_subnets_to_be_evaluated[cfg['net_id']] = cfg 115 | 116 | # aggregating all evaluation results 117 | eval_results = attentive_nas_eval.validate( 118 | my_subnets_to_be_evaluated, 119 | train_loader, 120 | val_loader, 121 | model, 122 | criterion, 123 | args, 124 | logger, 125 | ) 126 | 127 | # update the Pareto frontier 128 | # in this case, we search the best FLOPs vs. 
accuracy trade-offs
129 |         for cfg in eval_results:
130 |             f = round(cfg['flops'] / args.evo_search.step) * args.evo_search.step
131 |             if f not in pareto_global or pareto_global[f]['acc1'] < cfg['acc1']:
132 |                 pareto_global[f] = cfg
133 | 
134 |         # next batch of sub-networks to be evaluated
135 |         parent_popu = []
136 |         # mutate
137 |         for idx in range(args.evo_search.mutate_size):
138 |             while True:
139 |                 old_cfg = random.choice(list(pareto_global.values()))
140 |                 cfg = model.module.mutate_and_reset(old_cfg, prob=args.evo_search.mutate_prob)
141 |                 flops = model.module.compute_active_subnet_flops()
142 |                 if flops >= targeted_min_flops and flops <= targeted_max_flops:
143 |                     break
144 |             cfg['net_id'] = f'net_{idx % args.world_size}_evo_{evo}_mutate_{idx}'
145 |             parent_popu.append(cfg)
146 | 
147 |         # cross over
148 |         for idx in range(args.evo_search.crossover_size):
149 |             while True:
150 |                 cfg1 = random.choice(list(pareto_global.values()))
151 |                 cfg2 = random.choice(list(pareto_global.values()))
152 |                 cfg = model.module.crossover_and_reset(cfg1, cfg2)
153 |                 flops = model.module.compute_active_subnet_flops()
154 |                 if flops >= targeted_min_flops and flops <= targeted_max_flops:
155 |                     break
156 |             cfg['net_id'] = f'net_{idx % args.world_size}_evo_{evo}_crossover_{idx}'
157 |             parent_popu.append(cfg)
158 | 
159 | if __name__ == '__main__':
160 |     # set up environments
161 |     args = setup(run_args.config_file)
162 |     args.dist_url = run_args.dist_url
163 |     args.machine_rank = run_args.machine_rank
164 |     args.num_nodes = run_args.num_machines
165 | 
166 |     ngpus_per_node = torch.cuda.device_count()
167 | 
168 |     if args.multiprocessing_distributed:
169 |         # Since we have ngpus_per_node processes per node, the total world_size
170 |         # needs to be adjusted accordingly
171 |         args.world_size = ngpus_per_node * args.num_nodes
172 |         assert args.world_size > 1, "only support DDP settings"
173 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
174 |         # eval_worker process function
175 |         mp.spawn(eval_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
176 |     else:
177 |         raise NotImplementedError
178 | 
--------------------------------------------------------------------------------
/test_alphanet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved
2 | 
3 | # Modified from AttentiveNAS (https://github.com/facebookresearch/AttentiveNAS)
4 | import argparse
5 | import builtins
6 | import math
7 | import os
8 | import random
9 | import shutil
10 | import time
11 | import warnings
12 | import sys
13 | from datetime import date
14 | 
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.parallel
18 | import torch.backends.cudnn as cudnn
19 | import torch.distributed as dist
20 | import torch.optim
21 | import torch.multiprocessing as mp
22 | import torch.utils.data
23 | import torch.utils.data.distributed
24 | 
25 | import models
26 | from utils.config import setup
27 | from utils.flops_counter import count_net_flops_and_params
28 | import utils.comm as comm
29 | import utils.saver as saver
30 | 
31 | from data.data_loader import build_data_loader
32 | from utils.progress import AverageMeter, ProgressMeter, accuracy
33 | 
34 | 
35 | parser = argparse.ArgumentParser(description='Test AlphaNet Models')
36 | parser.add_argument('--config-file', default='./configs/eval_alphanet_models.yml')
37 | parser.add_argument('--model', default='a0', type=str, choices=['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a5_1', 'a6'])
38 | parser.add_argument('--gpu', default=0, type=int, help='gpu id')
39 | 
40 | run_args = parser.parse_args()
41 | 
42 | if __name__ == '__main__':
43 |     args = setup(run_args.config_file)
44 |     args.model = run_args.model
45 |     args.gpu = run_args.gpu
46 | 
47 |     random.seed(args.seed)
48 |     torch.manual_seed(args.seed)
49 |     torch.cuda.manual_seed(args.seed)
50 | 
51 |     args.__dict__['active_subnet'] = args.__dict__['pareto_models'][args.model]
52 |     print(args.active_subnet)
53 | 
54 |     train_loader, val_loader, train_sampler = build_data_loader(args)
55 | 
56 |     ## init static attentivenas model with weights inherited from the supernet
57 |     model = models.model_factory.create_model(args)
58 | 
59 |     model.to(args.gpu)
60 |     model.eval()
61 | 
62 |     # bn running stats calibration following Slimmable (https://arxiv.org/abs/1903.05134)
63 |     # please consider trying a different random seed if you see a small accuracy drop
64 |     with torch.no_grad():
65 |         model.reset_running_stats_for_calibration()
66 |         for batch_idx, (images, _) in enumerate(train_loader):
67 |             if batch_idx >= args.post_bn_calibration_batch_num:
68 |                 break
69 |             images = images.cuda(args.gpu, non_blocking=True)
70 |             model(images)  # forward only
71 | 
72 |     model.eval()
73 |     with torch.no_grad():
74 |         criterion = nn.CrossEntropyLoss().cuda()
75 | 
76 |         from evaluate.imagenet_eval import validate_one_subnet
77 |         acc1, acc5, loss, flops, params = validate_one_subnet(val_loader, model, criterion, args)
78 |         print(acc1, acc5, flops, params)
79 | 
80 | 
--------------------------------------------------------------------------------
/train_alphanet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | # Modified from AttentiveNAS (https://github.com/facebookresearch/AttentiveNAS) 4 | 5 | import argparse 6 | import builtins 7 | import math 8 | import os 9 | import random 10 | import shutil 11 | import time 12 | import warnings 13 | import sys 14 | import operator 15 | from datetime import date 16 | 17 | import torch 18 | import torch.nn as nn 19 | #from torch.utils.tensorboard import SummaryWriter 20 | import torch.nn.parallel 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | import torch.optim 24 | import torch.multiprocessing as mp 25 | import torch.utils.data 26 | import torch.utils.data.distributed 27 | 28 | from data.data_loader import build_data_loader 29 | 30 | from utils.config import setup 31 | import utils.saver as saver 32 | from utils.progress import AverageMeter, ProgressMeter, accuracy 33 | import utils.comm as comm 34 | import utils.logging as logging 35 | from evaluate import attentive_nas_eval as attentive_nas_eval 36 | from solver import build_optimizer, build_lr_scheduler 37 | import models 38 | from copy import deepcopy 39 | import numpy as np 40 | 41 | import loss_ops as loss_ops 42 | 43 | 44 | parser = argparse.ArgumentParser(description='AlphaNet Training') 45 | parser.add_argument('--config-file', default=None, type=str, 46 | help='training configuration') 47 | parser.add_argument('--machine-rank', default=0, type=int, 48 | help='machine rank, distributed setting') 49 | parser.add_argument('--num-machines', default=1, type=int, 50 | help='number of nodes, distributed setting') 51 | parser.add_argument('--dist-url', default="tcp://127.0.0.1:10001", type=str, 52 | help='init method, distributed setting') 53 | 54 | logger = logging.get_logger(__name__) 55 | 56 | 57 | def build_args_and_env(run_args): 58 | 59 | assert run_args.config_file and os.path.isfile(run_args.config_file), 'cannot locate config file' 60 | args = setup(run_args.config_file) 61 | args.config_file = run_args.config_file 62 | 63 | #load config 64 | assert args.distributed and args.multiprocessing_distributed, 'only support DDP training' 65 | args.distributed = True 66 | 67 | args.machine_rank = run_args.machine_rank 68 | args.num_nodes = run_args.num_machines 69 | args.dist_url = run_args.dist_url 70 | args.models_save_dir = os.path.join(args.models_save_dir, args.exp_name) 71 | 72 | if not os.path.exists(args.models_save_dir): 73 | os.makedirs(args.models_save_dir) 74 | 75 | #backup config file 76 | saver.copy_file(args.config_file, '{}/{}'.format(args.models_save_dir, os.path.basename(args.config_file))) 77 | 78 | args.checkpoint_save_path = os.path.join( 79 | args.models_save_dir, 'alphanet.pth.tar' 80 | ) 81 | args.logging_save_path = os.path.join( 82 | args.models_save_dir, f'stdout.log' 83 | ) 84 | return args 85 | 86 | 87 | def main(): 88 | run_args = parser.parse_args() 89 | args = build_args_and_env(run_args) 90 | 91 | random.seed(args.seed) 92 | torch.manual_seed(args.seed) 93 | #cudnn.deterministic = True 94 | #warnings.warn('You have chosen to seed training. ' 95 | # 'This will turn on the CUDNN deterministic setting, ' 96 | # 'which can slow down your training considerably! 
' 97 | # 'You may see unexpected behavior when restarting ' 98 | # 'from checkpoints.') 99 | 100 | ngpus_per_node = torch.cuda.device_count() 101 | if args.multiprocessing_distributed: 102 | # Since we have ngpus_per_node processes per node, the total world_size 103 | # needs to be adjusted accordingly 104 | args.world_size = ngpus_per_node * args.num_nodes 105 | # Use torch.multiprocessing.spawn to launch distributed processes: the 106 | # main_worker process function 107 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 108 | else: 109 | raise NotImplementedError 110 | 111 | assert args.world_size > 1, 'only support ddp training' 112 | 113 | 114 | def main_worker(gpu, ngpus_per_node, args): 115 | args.gpu = gpu # local rank, local machine cuda id 116 | args.local_rank = args.gpu 117 | args.batch_size = args.batch_size_per_gpu 118 | args.batch_size_total = args.batch_size * args.world_size 119 | #rescale base lr 120 | args.lr_scheduler.base_lr = args.lr_scheduler.base_lr * (max(1, args.batch_size_total // 256)) 121 | 122 | # set random seed, make sure all random subgraph generated would be the same 123 | random.seed(args.seed) 124 | torch.manual_seed(args.seed) 125 | if args.gpu: 126 | torch.cuda.manual_seed(args.seed) 127 | 128 | global_rank = args.gpu + args.machine_rank * ngpus_per_node 129 | dist.init_process_group( 130 | backend=args.dist_backend, 131 | init_method=args.dist_url, 132 | world_size=args.world_size, 133 | rank=global_rank 134 | ) 135 | 136 | # Setup logging format. 137 | logging.setup_logging(args.logging_save_path, 'w') 138 | 139 | logger.info(f"Use GPU: {args.gpu}, machine rank {args.machine_rank}, num_nodes {args.num_nodes}, \ 140 | gpu per node {ngpus_per_node}, world size {args.world_size}") 141 | 142 | # synchronize is needed here to prevent a possible timeout after calling 143 | # init_process_group 144 | # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 145 | comm.synchronize() 146 | 147 | args.rank = comm.get_rank() # global rank 148 | args.local_rank = args.gpu 149 | torch.cuda.set_device(args.gpu) 150 | 151 | # build model 152 | logger.info("=> creating model '{}'".format(args.arch)) 153 | model = models.model_factory.create_model(args) 154 | model.cuda(args.gpu) 155 | 156 | # use sync batchnorm 157 | if getattr(args, 'sync_bn', False): 158 | model.apply( 159 | lambda m: setattr(m, 'need_sync', True)) 160 | 161 | model = comm.get_parallel_model(model, args.gpu) #local rank 162 | 163 | logger.info(model) 164 | 165 | criterion = loss_ops.CrossEntropyLossSmooth(args.label_smoothing).cuda(args.gpu) 166 | soft_criterion = loss_ops.AdaptiveLossSoft(args.alpha_min, args.alpha_max, args.iw_clip).cuda(args.gpu) 167 | 168 | if not getattr(args, 'inplace_distill', True): 169 | soft_criterion = None 170 | 171 | ## load dataset, train_sampler: distributed 172 | train_loader, val_loader, train_sampler = build_data_loader(args) 173 | args.n_iters_per_epoch = len(train_loader) 174 | 175 | logger.info( f'building optimizer and lr scheduler, \ 176 | local rank {args.gpu}, global rank {args.rank}, world_size {args.world_size}') 177 | optimizer = build_optimizer(args, model) 178 | lr_scheduler = build_lr_scheduler(args, optimizer) 179 | 180 | # optionally resume from a checkpoint 181 | if args.resume: 182 | saver.load_checkpoints(args, model, optimizer, lr_scheduler, logger) 183 | 184 | logger.info(args) 185 | 186 | for epoch in range(args.start_epoch, args.epochs): 187 | if args.distributed: 188 | train_sampler.set_epoch(epoch) 
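            # DistributedSampler uses the epoch as its shuffling seed; setting it keeps the shuffle identical across ranks while varying it between epochs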
189 | 190 | args.curr_epoch = epoch 191 | logger.info('Training lr {}'.format(lr_scheduler.get_lr()[0])) 192 | 193 | # train for one epoch 194 | acc1, acc5 = train_epoch(epoch, model, train_loader, optimizer, criterion, args, \ 195 | soft_criterion=soft_criterion, lr_scheduler=lr_scheduler) 196 | 197 | if comm.is_master_process() or args.distributed: 198 | # validate supernet model 199 | validate( 200 | train_loader, val_loader, model, criterion, args 201 | ) 202 | 203 | if comm.is_master_process(): 204 | # save checkpoints 205 | saver.save_checkpoint( 206 | args.checkpoint_save_path, 207 | model, 208 | optimizer, 209 | lr_scheduler, 210 | args, 211 | epoch, 212 | ) 213 | 214 | 215 | def train_epoch( 216 | epoch, 217 | model, 218 | train_loader, 219 | optimizer, 220 | criterion, 221 | args, 222 | soft_criterion=None, 223 | lr_scheduler=None, 224 | ): 225 | batch_time = AverageMeter('Time', ':6.3f') 226 | data_time = AverageMeter('Data', ':6.3f') 227 | losses = AverageMeter('Loss', ':.4e') 228 | top1 = AverageMeter('Acc@1', ':6.2f') 229 | top5 = AverageMeter('Acc@5', ':6.2f') 230 | progress = ProgressMeter( 231 | len(train_loader), 232 | [batch_time, data_time, losses, top1, top5], 233 | prefix="Epoch: [{}]".format(epoch)) 234 | 235 | model.train() 236 | end = time.time() 237 | 238 | num_updates = epoch * len(train_loader) 239 | 240 | for batch_idx, (images, target) in enumerate(train_loader): 241 | # measure data loading time 242 | data_time.update(time.time() - end) 243 | 244 | images = images.cuda(args.gpu, non_blocking=True) 245 | target = target.cuda(args.gpu, non_blocking=True) 246 | 247 | # total subnets to be sampled 248 | num_subnet_training = max(2, getattr(args, 'num_arch_training', 2)) 249 | optimizer.zero_grad() 250 | 251 | ### compute gradients using sandwich rule ### 252 | # step 1 sample the largest network, apply regularization to only the largest network 253 | drop_connect_only_last_two_stages = getattr(args, 'drop_connect_only_last_two_stages', True) 254 | model.module.sample_max_subnet() 255 | model.module.set_dropout_rate(args.dropout, args.drop_connect, drop_connect_only_last_two_stages) #dropout for supernet 256 | output = model(images) 257 | loss = criterion(output, target) 258 | loss.backward() 259 | 260 | with torch.no_grad(): 261 | soft_logits = output.clone().detach() 262 | 263 | #step 2. 
sample the smallest network and several random networks
264 |         sandwich_rule = getattr(args, 'sandwich_rule', True)
265 |         model.module.set_dropout_rate(0, 0, drop_connect_only_last_two_stages)  # reset dropout rate
266 |         for arch_id in range(1, num_subnet_training):
267 |             if arch_id == num_subnet_training-1 and sandwich_rule:
268 |                 model.module.sample_min_subnet()
269 |             else:
270 |                 model.module.sample_active_subnet()
271 | 
272 |             # calculating loss
273 |             output = model(images)
274 | 
275 |             if soft_criterion:
276 |                 loss = soft_criterion(output, soft_logits)
277 |             else:
278 |                 assert not args.inplace_distill
279 |                 loss = criterion(output, target)
280 | 
281 |             loss.backward()
282 | 
283 |         # clip gradients if specified
284 |         if getattr(args, 'grad_clip_value', None):
285 |             torch.nn.utils.clip_grad_value_(model.parameters(), args.grad_clip_value)
286 | 
287 |         optimizer.step()
288 | 
289 |         # accuracy measured on the local batch
290 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
291 |         if args.distributed:
292 |             corr1, corr5, loss = acc1*args.batch_size, acc5*args.batch_size, loss.item()*args.batch_size  # just in case the batch size is different on different nodes
293 |             stats = torch.tensor([corr1, corr5, loss, args.batch_size], device=args.gpu)
294 |             dist.barrier()  # synchronizes all processes
295 |             dist.all_reduce(stats, op=torch.distributed.ReduceOp.SUM)
296 |             corr1, corr5, loss, batch_size = stats.tolist()
297 |             acc1, acc5, loss = corr1/batch_size, corr5/batch_size, loss/batch_size
298 |             losses.update(loss, batch_size)
299 |             top1.update(acc1, batch_size)
300 |             top5.update(acc5, batch_size)
301 |         else:
302 |             losses.update(loss.item(), images.size(0))
303 |             top1.update(acc1, images.size(0))
304 |             top5.update(acc5, images.size(0))
305 | 
306 | 
307 |         # measure elapsed time
308 |         batch_time.update(time.time() - end)
309 |         end = time.time()
310 | 
311 |         num_updates += 1
312 |         if lr_scheduler is not None:
313 |             lr_scheduler.step()
314 | 
315 |         if batch_idx % args.print_freq == 0:
316 |             progress.display(batch_idx, logger)
317 | 
318 |     return top1.avg, top5.avg
319 | 
320 | 
321 | def validate(
322 |     train_loader,
323 |     val_loader,
324 |     model,
325 |     criterion,
326 |     args,
327 |     distributed = True,
328 | ):
329 |     subnets_to_be_evaluated = {
330 |         'attentive_nas_min_net': {},
331 |         'attentive_nas_max_net': {},
332 |     }
333 | 
334 |     acc1_list, acc5_list = attentive_nas_eval.validate(
335 |         subnets_to_be_evaluated,
336 |         train_loader,
337 |         val_loader,
338 |         model,
339 |         criterion,
340 |         args,
341 |         logger,
342 |         bn_calibration = True,
343 |     )
344 | 
345 | 
346 | 
347 | if __name__ == '__main__':
348 |     main()
349 | 
350 | 
351 | 
--------------------------------------------------------------------------------