├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── INSTALL.md ├── LICENSE ├── README.md ├── datasets.py ├── drop_scheduler.py ├── engine.py ├── main.py ├── models ├── convnext.py ├── mlp_mixer.py ├── swin_transformer.py └── vision_transformer.py ├── optim_factory.py ├── run_with_submitit.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. 
The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to this repo, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | We provide installation instructions for ImageNet classification experiments here. 4 | 5 | ## Dependency Setup 6 | Create an new conda virtual environment 7 | ``` 8 | conda create -n dropout python=3.8 -y 9 | conda activate dropout 10 | ``` 11 | 12 | Install [Pytorch](https://pytorch.org/)>=1.8.0, [torchvision](https://pytorch.org/vision/stable/index.html)>=0.9.0 following official instructions. For example: 13 | ``` 14 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 15 | ``` 16 | 17 | Clone this repo and install required packages: 18 | ``` 19 | git clone https://github.com/facebookresearch/dropout 20 | pip install timm==0.4.12 tensorboardX six 21 | ``` 22 | 23 | The results in the paper are produced with `torch==1.8.0+cu111 torchvision==0.9.0+cu111 timm==0.4.12`. 
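As an optional sanity check (an illustrative snippet that only assumes the packages above installed successfully), you can confirm the versions and that CUDA is visible:
```
python -c "import torch, torchvision, timm; print(torch.__version__, torchvision.__version__, timm.__version__, torch.cuda.is_available())"
```
For the paper's exact setup this should print `1.8.0+cu111 0.9.0+cu111 0.4.12 True`, where `True` indicates CUDA is available.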
24 | 25 | ## Dataset Preparation 26 | 27 | Download the [ImageNet-1K](http://image-net.org/) classification dataset and structure the data as follows: 28 | ``` 29 | /path/to/imagenet-1k/ 30 | train/ 31 | class1/ 32 | img1.jpeg 33 | class2/ 34 | img2.jpeg 35 | val/ 36 | class1/ 37 | img3.jpeg 38 | class2/ 39 | img4.jpeg 40 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. 
More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. 
Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. 
For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dropout Reduces Underfitting 2 | 3 | Official PyTorch implementation for **Dropout Reduces Underfitting** 4 | 5 | > [**Dropout Reduces Underfitting**](https://arxiv.org/abs/2303.01500), ICML 2023
6 | > [Zhuang Liu*](https://liuzhuang13.github.io), [Zhiqiu Xu*](https://oscarxzq.github.io), [Joseph Jin](https://www.linkedin.com/in/joseph-jin/), [Zhiqiang Shen](https://zhiqiangshen.com/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/) (* equal contribution) 7 | >
Meta AI, UC Berkeley and MBZUAI
8 | 9 | 10 | [figure omitted: illustration of early dropout and late dropout schedules] 11 | 12 |
13 | 14 | Figure: We propose **early dropout** and **late dropout**. Early dropout helps underfitting models fit the data better and achieve lower training loss. Late dropout helps improve the generalization performance of overfitting models. 15 | 16 | 17 | ## Results on ImageNet-1K 18 | 19 | Model weights are released as links on results. 20 | 21 | ### Early Dropout 22 | 23 | results with basic recipe (s.d. = stochastic depth) 24 | 25 | | model| ViT-T | Mixer-S | Swin-F | ConvNeXt-F | 26 | |:---|:---:|:---:|:---:|:---:| 27 | | no dropout | 73.9 | 71.0 | 74.3 | 76.1 | 28 | | standard dropout | 67.9 | 67.1 | 71.6 | - | 29 | | standard s.d. | 72.6 | 70.5 | 73.7 | 75.5 | 30 | | early dropout | [**74.3**](https://drive.google.com/file/d/1Sk93Fz8Pih7qLvAcyRtRjJ2w2yihh7S3/view?usp=share_link) | [**71.3**](https://drive.google.com/file/d/199i9rRD-u2DA22qmoZH774ibyhFUH_mE/view?usp=share_link) | [**74.7**](https://drive.google.com/file/d/1gP6vnd-wNyRiU_BTxWHbvrhNykMH-N0r/view?usp=share_link) | - | 31 | | early s.d. | [**74.4**](https://drive.google.com/file/d/1E550KlZVgsK30u9wX5q5tXXUAxB9gfPg/view?usp=share_link) | [**71.7**](https://drive.google.com/file/d/1jPtWufetAQhM4oe6wOgdTsKozYCXRmdb/view?usp=share_link) | [**75.2**](https://drive.google.com/file/d/1kmFlmG_-eIdj_MLfHWtEvbvTPE6IBcYQ/view?usp=share_link) | [**76.3**](https://drive.google.com/file/d/10v7Ua4f6FlJ3VIf6TV35beSBxjGTcv9U/view?usp=share_link) | 32 | 33 | 34 | results with improved recipe 35 | 36 | | model | ViT-T | Swin-F | ConvNeXt-F | 37 | |:------------|:-----:|:------:|:----------:| 38 | | no dropout | 76.3 | 76.1 | 77.5 | 39 | | standard dropout | 71.5 | 73.5 | - | 40 | | standard s.d. | 75.6 | 75.6 | 77.4 | 41 | | early dropout | [**76.7**](https://drive.google.com/file/d/1q3kopfA2KazTaR9kuEEM5lzdNKHX2OQl/view?usp=share_link) | [**76.6**](https://drive.google.com/file/d/1Os16aIWD1WpSlccsFboesc0BgXN6KJ9C/view?usp=share_link) | - | 42 | | early s.d. | [**76.7**](https://drive.google.com/file/d/1GTfGbNObvGDytdb9F5wgUHxnhlVRXs6o/view?usp=share_link) | [**76.6**](https://drive.google.com/file/d/17mNr8e-TVQoVM0I6IxVJNC4f3Y--T4R_/view?usp=share_link) | [**77.7**](https://drive.google.com/file/d/1sIePqyxk5ajVdsSCRCJoTIZsV8O_fKmO/view?usp=share_link) | 43 | 44 | 45 | ### Late Dropout 46 | results with basic recipe 47 | 48 | | model | ViT-B | Mixer-B | 49 | |:------------:|:-----:|:-------:| 50 | | standard s.d. | 81.6 | 78.0 | 51 | | late s.d. | [**82.3**](https://drive.google.com/file/d/1_AB51g6AHF-C9oGWffwOw4C1Xug1LT_0/view?usp=share_link) | [**78.6**](https://drive.google.com/file/d/1CWEi8hyEIKz7F21HlsaEIgp8eaHNFHfe/view?usp=share_link) | 52 | 53 | 54 | ## Installation 55 | Please check [INSTALL.md](INSTALL.md) for installation instructions. 56 | 57 | ## Training 58 | 59 | ### Basic Recipe 60 | We list commands for early dropout, early stochastic depth on `ViT-T` and late stochastic depth on `ViT-B`. 61 | - For training other models, change `--model` accordingly, e.g., to `vit_tiny`, `mixer_s32`, `convnext_femto`, `mixer_b16`, `vit_base`. 62 | - Our results were produced with 4 nodes, each with 8 gpus. Below we give example commands on both multi-node and single-machine setups. 
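Note: the effective batch size is `--batch_size` (per GPU) × `--update_freq` × the number of GPUs, and the default learning rate `4e-3` assumes an effective batch size of 4096. The commands below keep this constant: 128 × 1 × 32 = 4096 for the 4-node run, and 128 × 4 × 8 = 4096 for the single-machine run.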
63 | 64 | **Early dropout** 65 | 66 | multi-node 67 | ``` 68 | python run_with_submitit.py --nodes 4 --ngpus 8 \ 69 | --model vit_tiny --epochs 300 \ 70 | --batch_size 128 --lr 4e-3 --update_freq 1 \ 71 | --dropout 0.1 --drop_mode early --drop_schedule linear --cutoff_epoch 50 \ 72 | --data_path /path/to/data/ \ 73 | --output_dir /path/to/results/ 74 | ``` 75 | 76 | single-machine 77 | ``` 78 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 79 | --model vit_tiny --epochs 300 \ 80 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 81 | --dropout 0.1 --drop_mode early --drop_schedule linear --cutoff_epoch 50 \ 82 | --data_path /path/to/data/ \ 83 | --output_dir /path/to/results/ 84 | ``` 85 | 86 | **Early stochastic depth** 87 | ``` 88 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 89 | --model vit_tiny --epochs 300 \ 90 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 91 | --drop_path 0.5 --drop_mode early --drop_schedule linear --cutoff_epoch 50 \ 92 | --data_path /path/to/data/ \ 93 | --output_dir /path/to/results/ 94 | ``` 95 | 96 | **Late stochastic depth** 97 | ``` 98 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 99 | --model vit_base --epochs 300 \ 100 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 101 | --drop_path 0.4 --drop_mode late --drop_schedule constant --cutoff_epoch 50 \ 102 | --data_path /path/to/data/ \ 103 | --output_dir /path/to/results/ 104 | ``` 105 | 106 | **Standard dropout / no dropout** (replace $p with 0.1 / 0.0) 107 | ``` 108 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 109 | --model vit_tiny --epochs 300 \ 110 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 111 | --dropout $p --drop_mode standard \ 112 | --data_path /path/to/data/ \ 113 | --output_dir /path/to/results/ 114 | ``` 115 | 116 | 117 | ### Improved Recipe 118 | Our improved recipe extends training epochs from `300` to `600`, and reduces both `mixup` and `cutmix` to `0.3`. 119 | 120 | **Early dropout** 121 | ``` 122 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 123 | --model vit_tiny --epochs 600 --mixup 0.3 --cutmix 0.3 \ 124 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 125 | --dropout 0.1 --drop_mode early --drop_schedule linear --cutoff_epoch 50 \ 126 | --data_path /path/to/data/ \ 127 | --output_dir /path/to/results/ 128 | ``` 129 | 130 | **Early stochastic depth** 131 | ``` 132 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 133 | --model vit_tiny --epochs 600 --mixup 0.3 --cutmix 0.3 \ 134 | --batch_size 128 --lr 4e-3 --update_freq 4 \ 135 | --drop_path 0.5 --drop_mode early --drop_schedule linear --cutoff_epoch 50 \ 136 | --data_path /path/to/data/ \ 137 | --output_dir /path/to/results/ 138 | ``` 139 | 140 | ### Evaluation 141 | 142 | single-GPU 143 | ``` 144 | python main.py --model vit_tiny --eval true \ 145 | --resume /path/to/model \ 146 | --data_path /path/to/data 147 | ``` 148 | 149 | multi-GPU 150 | ``` 151 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 152 | --model vit_tiny --eval true \ 153 | --resume /path/to/model \ 154 | --data_path /path/to/data 155 | ``` 156 | 157 | 158 | ## Acknowledgement 159 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library and [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) codebase. 160 | 161 | ## License 162 | This project is released under the CC-BY-NC 4.0 license. Please see the [LICENSE](LICENSE) file for more information. 
163 | 164 | ## Citation 165 | If you find this repository helpful, please consider citing: 166 | ```bibtex 167 | @inproceedings{liu2023dropout, 168 | title={Dropout Reduces Underfitting}, 169 | author={Zhuang Liu, Zhiqiu Xu, Joseph Jin, Zhiqiang Shen, Trevor Darrell}, 170 | year={2023}, 171 | booktitle={International Conference on Machine Learning}, 172 | } 173 | ``` 174 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | from torchvision import datasets, transforms 10 | 11 | from timm.data.constants import \ 12 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 13 | from timm.data import create_transform 14 | 15 | def build_dataset(is_train, args): 16 | transform = build_transform(is_train, args) 17 | 18 | print("Transform = ") 19 | if isinstance(transform, tuple): 20 | for trans in transform: 21 | print(" - - - - - - - - - - ") 22 | for t in trans.transforms: 23 | print(t) 24 | else: 25 | for t in transform.transforms: 26 | print(t) 27 | print("---------------------------") 28 | 29 | if args.data_set == 'CIFAR': 30 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 31 | nb_classes = 100 32 | elif args.data_set == 'IMNET': 33 | print("reading from datapath", args.data_path) 34 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 35 | dataset = datasets.ImageFolder(root, transform=transform) 36 | nb_classes = 1000 37 | elif args.data_set == "image_folder": 38 | root = args.data_path if is_train else args.eval_data_path 39 | dataset = datasets.ImageFolder(root, transform=transform) 40 | nb_classes = args.nb_classes 41 | assert len(dataset.class_to_idx) == nb_classes 42 | else: 43 | raise NotImplementedError() 44 | print("Number of the class = %d" % nb_classes) 45 | 46 | return dataset, nb_classes 47 | 48 | 49 | def build_transform(is_train, args): 50 | resize_im = args.input_size > 32 51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 54 | 55 | if is_train: 56 | # this should always dispatch to transforms_imagenet_train 57 | transform = create_transform( 58 | input_size=args.input_size, 59 | is_training=True, 60 | color_jitter=args.color_jitter, 61 | auto_augment=args.aa, 62 | interpolation=args.train_interpolation, 63 | re_prob=args.reprob, 64 | re_mode=args.remode, 65 | re_count=args.recount, 66 | mean=mean, 67 | std=std, 68 | ) 69 | if not resize_im: 70 | transform.transforms[0] = transforms.RandomCrop( 71 | args.input_size, padding=4) 72 | return transform 73 | 74 | t = [] 75 | if resize_im: 76 | # warping (no cropping) when evaluated at 384 or larger 77 | if args.input_size >= 384: 78 | t.append( 79 | transforms.Resize((args.input_size, args.input_size), 80 | interpolation=transforms.InterpolationMode.BICUBIC), 81 | ) 82 | print(f"Warping {args.input_size} size input images...") 83 | else: 84 | if args.crop_pct is None: 85 | args.crop_pct = 224 / 256 86 | size = int(args.input_size / args.crop_pct) 87 | 
t.append( 88 | # to maintain same ratio w.r.t. 224 images 89 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 90 | ) 91 | t.append(transforms.CenterCrop(args.input_size)) 92 | 93 | t.append(transforms.ToTensor()) 94 | t.append(transforms.Normalize(mean, std)) 95 | return transforms.Compose(t) 96 | -------------------------------------------------------------------------------- /drop_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def drop_scheduler(drop_rate, epochs, niter_per_ep, cutoff_epoch=0, mode="standard", schedule="constant"): 4 | assert mode in ["standard", "early", "late"] 5 | if mode == "standard": 6 | return np.full(epochs * niter_per_ep, drop_rate) 7 | 8 | early_iters = cutoff_epoch * niter_per_ep 9 | late_iters = (epochs - cutoff_epoch) * niter_per_ep 10 | 11 | if mode == "early": 12 | assert schedule in ["constant", "linear"] 13 | if schedule == 'constant': 14 | early_schedule = np.full(early_iters, drop_rate) 15 | elif schedule == 'linear': 16 | early_schedule = np.linspace(drop_rate, 0, early_iters) 17 | final_schedule = np.concatenate((early_schedule, np.full(late_iters, 0))) 18 | 19 | elif mode == "late": 20 | assert schedule in ["constant"] 21 | early_schedule = np.full(early_iters, 0) 22 | final_schedule = np.concatenate((early_schedule, np.full(late_iters, drop_rate))) 23 | 24 | assert len(final_schedule) == epochs * niter_per_ep 25 | return final_schedule -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
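# Note (illustrative sketch, not part of the original file): the `schedules` dict that
# train_one_epoch() below indexes with the global iteration `it` (schedules['do'][it],
# schedules['dp'][it]) is built from drop_scheduler.drop_scheduler, e.g. for the
# early-dropout recipe in the README (dropout 0.1, linear schedule, cutoff epoch 50):
#
#     from drop_scheduler import drop_scheduler
#     schedules = {}
#     if args.dropout > 0:
#         schedules['do'] = drop_scheduler(args.dropout, args.epochs,
#                                          num_training_steps_per_epoch,
#                                          args.cutoff_epoch, args.drop_mode, args.drop_schedule)
#     if args.drop_path > 0:
#         schedules['dp'] = drop_scheduler(args.drop_path, args.epochs,
#                                          num_training_steps_per_epoch,
#                                          args.cutoff_epoch, args.drop_mode, args.drop_schedule)
#
# With mode='early' and schedule='linear', the returned per-iteration rate decays linearly
# from the base rate to 0 over the first `cutoff_epoch` epochs and stays at 0 afterwards.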
7 | 8 | import math 9 | from typing import Iterable, Optional 10 | import torch 11 | from timm.data import Mixup 12 | from timm.utils import accuracy, ModelEma 13 | 14 | import utils 15 | 16 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 17 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 18 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 19 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 20 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, schedules={}, 21 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False): 22 | model.train(True) 23 | metric_logger = utils.MetricLogger(delimiter=" ") 24 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 25 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | header = 'Epoch: [{}]'.format(epoch) 27 | print_freq = 10 28 | 29 | optimizer.zero_grad() 30 | 31 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 32 | step = data_iter_step // update_freq 33 | if step >= num_training_steps_per_epoch: 34 | continue 35 | it = start_steps + step # global training iteration 36 | # Update LR & WD for the first acc 37 | if data_iter_step % update_freq == 0: 38 | if lr_schedule_values is not None or wd_schedule_values is not None: 39 | for i, param_group in enumerate(optimizer.param_groups): 40 | if lr_schedule_values is not None: 41 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 42 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 43 | param_group["weight_decay"] = wd_schedule_values[it] 44 | if 'dp' in schedules: 45 | model.module.update_drop_path(schedules['dp'][it]) 46 | if 'do' in schedules: 47 | model.module.update_dropout(schedules['do'][it]) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | if use_amp: 56 | with torch.cuda.amp.autocast(): 57 | output = model(samples) 58 | loss = criterion(output, targets) 59 | else: # full precision 60 | output = model(samples) 61 | loss = criterion(output, targets) 62 | 63 | loss_value = loss.item() 64 | 65 | if not math.isfinite(loss_value): # this could trigger if using AMP 66 | print("Loss is {}, stopping training".format(loss_value)) 67 | assert math.isfinite(loss_value) 68 | 69 | if use_amp: 70 | # this attribute is added by timm on one optimizer (adahessian) 71 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 72 | loss /= update_freq 73 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 74 | parameters=model.parameters(), create_graph=is_second_order, 75 | update_grad=(data_iter_step + 1) % update_freq == 0) 76 | if (data_iter_step + 1) % update_freq == 0: 77 | optimizer.zero_grad() 78 | if model_ema is not None: 79 | model_ema.update(model) 80 | else: # full precision 81 | loss /= update_freq 82 | loss.backward() 83 | if (data_iter_step + 1) % update_freq == 0: 84 | optimizer.step() 85 | optimizer.zero_grad() 86 | if model_ema is not None: 87 | model_ema.update(model) 88 | 89 | torch.cuda.synchronize() 90 | 91 | if mixup_fn is None: 92 | class_acc = (output.max(-1)[-1] == targets).float().mean() 93 | else: 94 | class_acc = None 95 | metric_logger.update(loss=loss_value) 96 | 
metric_logger.update(class_acc=class_acc) 97 | min_lr = 10. 98 | max_lr = 0. 99 | for group in optimizer.param_groups: 100 | min_lr = min(min_lr, group["lr"]) 101 | max_lr = max(max_lr, group["lr"]) 102 | 103 | metric_logger.update(lr=max_lr) 104 | metric_logger.update(min_lr=min_lr) 105 | weight_decay_value = None 106 | for group in optimizer.param_groups: 107 | if group["weight_decay"] > 0: 108 | weight_decay_value = group["weight_decay"] 109 | metric_logger.update(weight_decay=weight_decay_value) 110 | 111 | if 'dp' in schedules: 112 | metric_logger.update(drop_path=model.module.drop_path) 113 | 114 | if 'do' in schedules: 115 | metric_logger.update(dropout=model.module.drop_rate) 116 | 117 | if use_amp: 118 | metric_logger.update(grad_norm=grad_norm) 119 | 120 | if log_writer is not None: 121 | log_writer.update(loss=loss_value, head="loss") 122 | log_writer.update(class_acc=class_acc, head="loss") 123 | log_writer.update(lr=max_lr, head="opt") 124 | log_writer.update(min_lr=min_lr, head="opt") 125 | log_writer.update(weight_decay=weight_decay_value, head="opt") 126 | if use_amp: 127 | log_writer.update(grad_norm=grad_norm, head="opt") 128 | log_writer.set_step() 129 | 130 | if wandb_logger: 131 | wandb_logger._wandb.log({ 132 | 'Rank-0 Batch Wise/train_loss': loss_value, 133 | 'Rank-0 Batch Wise/train_max_lr': max_lr, 134 | 'Rank-0 Batch Wise/train_min_lr': min_lr 135 | }, commit=False) 136 | if class_acc: 137 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False) 138 | if use_amp: 139 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False) 140 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it}) 141 | 142 | 143 | # gather the stats from all processes 144 | metric_logger.synchronize_between_processes() 145 | print("Averaged stats:", metric_logger) 146 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 147 | 148 | @torch.no_grad() 149 | def evaluate(data_loader, model, device, use_amp=False): 150 | criterion = torch.nn.CrossEntropyLoss() 151 | 152 | metric_logger = utils.MetricLogger(delimiter=" ") 153 | header = 'Test:' 154 | 155 | # switch to evaluation mode 156 | model.eval() 157 | for batch in metric_logger.log_every(data_loader, 10, header): 158 | images = batch[0] 159 | target = batch[-1] 160 | 161 | images = images.to(device, non_blocking=True) 162 | target = target.to(device, non_blocking=True) 163 | 164 | # compute output 165 | if use_amp: 166 | with torch.cuda.amp.autocast(): 167 | output = model(images) 168 | loss = criterion(output, target) 169 | else: 170 | output = model(images) 171 | loss = criterion(output, target) 172 | 173 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 174 | 175 | batch_size = images.shape[0] 176 | metric_logger.update(loss=loss.item()) 177 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 178 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 179 | # gather the stats from all processes 180 | metric_logger.synchronize_between_processes() 181 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 182 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 183 | 184 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 185 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import argparse 9 | import datetime 10 | import numpy as np 11 | import time 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | import json 16 | import os 17 | 18 | from pathlib import Path 19 | 20 | from timm.data.mixup import Mixup 21 | from timm.models import create_model 22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 23 | from timm.utils import ModelEma 24 | from optim_factory import create_optimizer, LayerDecayValueAssigner 25 | 26 | from datasets import build_dataset 27 | from engine import train_one_epoch, evaluate 28 | 29 | from utils import NativeScalerWithGradNormCount as NativeScaler 30 | import utils 31 | from drop_scheduler import drop_scheduler 32 | 33 | import models.convnext 34 | import models.vision_transformer 35 | import models.swin_transformer 36 | import models.mlp_mixer 37 | 38 | def str2bool(v): 39 | """ 40 | Converts string to bool type; enables command line 41 | arguments in the format of '--arg1 true --arg2 false' 42 | """ 43 | if isinstance(v, bool): 44 | return v 45 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 46 | return True 47 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 48 | return False 49 | else: 50 | raise argparse.ArgumentTypeError('Boolean value expected.') 51 | 52 | def get_args_parser(): 53 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 54 | parser.add_argument('--batch_size', default=64, type=int, 55 | help='Per GPU batch size') 56 | parser.add_argument('--epochs', default=300, type=int) 57 | parser.add_argument('--update_freq', default=1, type=int, 58 | help='gradient accumulation steps') 59 | 60 | # Model parameters 61 | parser.add_argument('--model', default='convnext_tiny', type=str, metavar='MODEL', 62 | help='Name of model to train') 63 | parser.add_argument('--input_size', default=224, type=int, 64 | help='image input size') 65 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 66 | help="Layer scale initial values") 67 | 68 | ########################## settings specific to this project ########################## 69 | 70 | # dropout and stochastic depth drop rate; set at most one to non-zero 71 | parser.add_argument('--dropout', type=float, default=0, metavar='PCT', 72 | help='Drop path rate (default: 0.0)') 73 | parser.add_argument('--drop_path', type=float, default=0, metavar='PCT', 74 | help='Drop path rate (default: 0.0)') 75 | 76 | # early / late dropout and stochastic depth settings 77 | parser.add_argument('--drop_mode', type=str, default='standard', choices=['standard', 'early', 'late'], help='drop mode') 78 | parser.add_argument('--drop_schedule', type=str, default='constant', choices=['constant', 'linear'], 79 | help='drop schedule for early dropout / s.d. 
only') 80 | parser.add_argument('--cutoff_epoch', type=int, default=0, 81 | help='if drop_mode is early / late, this is the epoch where dropout ends / starts') 82 | 83 | ####################################################################################### 84 | 85 | # EMA related parameters 86 | parser.add_argument('--model_ema', type=str2bool, default=False) 87 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 88 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 89 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 90 | 91 | # Optimization parameters 92 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 93 | help='Optimizer (default: "adamw"') 94 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 95 | help='Optimizer Epsilon (default: 1e-8)') 96 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 97 | help='Optimizer Betas (default: None, use opt default)') 98 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 99 | help='Clip gradient norm (default: None, no clipping)') 100 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 101 | help='SGD momentum (default: 0.9)') 102 | parser.add_argument('--weight_decay', type=float, default=0.05, 103 | help='weight decay (default: 0.05)') 104 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 105 | weight decay. We use a cosine schedule for WD and using a larger decay by 106 | the end of training improves performance for ViTs.""") 107 | 108 | parser.add_argument('--lr', type=float, default=4e-3, metavar='LR', 109 | help='learning rate (default: 4e-3), with total batch size 4096') 110 | parser.add_argument('--layer_decay', type=float, default=1.0) 111 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 112 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 113 | parser.add_argument('--warmup_epochs', type=int, default=50, metavar='N', 114 | help='epochs to warmup LR, if scheduler supports') 115 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 116 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 117 | 118 | # Augmentation parameters 119 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 120 | help='Color jitter factor (default: 0.4)') 121 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 122 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 123 | parser.add_argument('--smoothing', type=float, default=0.1, 124 | help='Label smoothing (default: 0.1)') 125 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 126 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 127 | 128 | # Evaluation parameters 129 | parser.add_argument('--crop_pct', type=float, default=None) 130 | 131 | # * Random Erase params 132 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 133 | help='Random erase prob (default: 0.25)') 134 | parser.add_argument('--remode', type=str, default='pixel', 135 | help='Random erase mode (default: "pixel")') 136 | parser.add_argument('--recount', type=int, default=1, 137 | help='Random erase count (default: 1)') 138 | parser.add_argument('--resplit', type=str2bool, default=False, 139 | help='Do not random erase first (clean) augmentation split') 140 | 141 | # * Mixup params 142 | parser.add_argument('--mixup', type=float, default=0.8, 143 | help='mixup alpha, mixup enabled if > 0.') 144 | parser.add_argument('--cutmix', type=float, default=1.0, 145 | help='cutmix alpha, cutmix enabled if > 0.') 146 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 147 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 148 | parser.add_argument('--mixup_prob', type=float, default=1.0, 149 | help='Probability of performing mixup or cutmix when either/both is enabled') 150 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 151 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 152 | parser.add_argument('--mixup_mode', type=str, default='batch', 153 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 154 | 155 | # * Finetuning params 156 | parser.add_argument('--finetune', default='', 157 | help='finetune from checkpoint') 158 | parser.add_argument('--head_init_scale', default=1.0, type=float, 159 | help='classifier head initial scale, typically adjusted in fine-tuning') 160 | parser.add_argument('--model_key', default='model|module', type=str, 161 | help='which key to load from saved state dict, usually model or model_ema') 162 | parser.add_argument('--model_prefix', default='', type=str) 163 | 164 | # Dataset parameters 165 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 166 | help='dataset path') 167 | parser.add_argument('--eval_data_path', default=None, type=str, 168 | help='dataset path for evaluation') 169 | parser.add_argument('--nb_classes', default=1000, type=int, 170 | help='number of the classification types') 171 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 172 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'], 173 | type=str, help='ImageNet dataset path') 174 | parser.add_argument('--output_dir', default='', 175 | help='path where to save, empty for no saving') 176 | parser.add_argument('--log_dir', default=None, 177 | help='path where to tensorboard log') 178 | parser.add_argument('--device', default='cuda', 179 | help='device to use for training / testing') 180 | parser.add_argument('--seed', default=0, type=int) 181 | 182 | parser.add_argument('--resume', default='', 183 | help='resume from checkpoint') 184 | parser.add_argument('--auto_resume', type=str2bool, default=True) 185 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 186 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 187 | parser.add_argument('--save_ckpt_num', default=3, type=int) 188 | 189 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 190 | help='start epoch') 191 | parser.add_argument('--eval', type=str2bool, default=False, 192 | help='Perform evaluation only') 193 | parser.add_argument('--dist_eval', type=str2bool, default=True, 194 | help='Enabling distributed evaluation') 195 | parser.add_argument('--disable_eval', type=str2bool, default=False, 196 | help='Disabling evaluation during training') 197 | parser.add_argument('--num_workers', default=10, type=int) 198 | parser.add_argument('--pin_mem', type=str2bool, default=True, 199 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 200 | 201 | # distributed training parameters 202 | parser.add_argument('--world_size', default=1, type=int, 203 | help='number of distributed processes') 204 | parser.add_argument('--local_rank', default=-1, type=int) 205 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 206 | parser.add_argument('--dist_url', default='env://', 207 | help='url used to set up distributed training') 208 | 209 | parser.add_argument('--use_amp', type=str2bool, default=False, 210 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 211 | 212 | # Weights and Biases arguments 213 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 214 | help="enable logging to Weights and Biases") 215 | parser.add_argument('--project', default='convnext', type=str, 216 | help="The name of the W&B project where you're sending the new run.") 217 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 218 | help="Save model checkpoints as W&B 
Artifacts.") 219 | 220 | return parser 221 | 222 | def main(args): 223 | utils.init_distributed_mode(args) 224 | print(args) 225 | device = torch.device(args.device) 226 | 227 | # fix the seed for reproducibility 228 | seed = args.seed + utils.get_rank() 229 | torch.manual_seed(seed) 230 | np.random.seed(seed) 231 | cudnn.benchmark = True 232 | 233 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 234 | if args.disable_eval: 235 | args.dist_eval = False 236 | dataset_val = None 237 | else: 238 | dataset_val, _ = build_dataset(is_train=False, args=args) 239 | 240 | num_tasks = utils.get_world_size() 241 | global_rank = utils.get_rank() 242 | 243 | sampler_train = torch.utils.data.DistributedSampler( 244 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 245 | ) 246 | print("Sampler_train = %s" % str(sampler_train)) 247 | if args.dist_eval: 248 | if len(dataset_val) % num_tasks != 0: 249 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 250 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 251 | 'equal num of samples per-process.') 252 | sampler_val = torch.utils.data.DistributedSampler( 253 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 254 | else: 255 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 256 | 257 | if global_rank == 0 and args.log_dir is not None: 258 | os.makedirs(args.log_dir, exist_ok=True) 259 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 260 | else: 261 | log_writer = None 262 | 263 | if global_rank == 0 and args.enable_wandb: 264 | wandb_logger = utils.WandbLogger(args) 265 | else: 266 | wandb_logger = None 267 | 268 | data_loader_train = torch.utils.data.DataLoader( 269 | dataset_train, sampler=sampler_train, 270 | batch_size=args.batch_size, 271 | num_workers=args.num_workers, 272 | pin_memory=args.pin_mem, 273 | drop_last=True, 274 | ) 275 | 276 | if dataset_val is not None: 277 | data_loader_val = torch.utils.data.DataLoader( 278 | dataset_val, sampler=sampler_val, 279 | batch_size=int(1.5 * args.batch_size), 280 | num_workers=args.num_workers, 281 | pin_memory=args.pin_mem, 282 | drop_last=False 283 | ) 284 | else: 285 | data_loader_val = None 286 | 287 | mixup_fn = None 288 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None
289 |     if mixup_active:
290 |         print("Mixup is activated!")
291 |         mixup_fn = Mixup(
292 |             mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
293 |             prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
294 |             label_smoothing=args.smoothing, num_classes=args.nb_classes)
295 | 
296 |     model = utils.build_model(args)
297 |     if args.finetune:
298 |         if args.finetune.startswith('https'):
299 |             checkpoint = torch.hub.load_state_dict_from_url(
300 |                 args.finetune, map_location='cpu', check_hash=True)
301 |         else:
302 |             checkpoint = torch.load(args.finetune, map_location='cpu')
303 | 
304 |         print("Load ckpt from %s" % args.finetune)
305 |         checkpoint_model = None
306 |         for model_key in args.model_key.split('|'):
307 |             if model_key in checkpoint:
308 |                 checkpoint_model = checkpoint[model_key]
309 |                 print("Load state_dict by model_key = %s" % model_key)
310 |                 break
311 |         if checkpoint_model is None:
312 |             checkpoint_model = checkpoint
313 |         state_dict = model.state_dict()
314 |         for k in ['head.weight', 'head.bias']:
315 |             if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
316 |                 print(f"Removing key {k} from pretrained checkpoint")
317 |                 del checkpoint_model[k]
318 |         utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
319 |     model.to(device)
320 | 
321 |     model_ema = None
322 |     if args.model_ema:
323 |         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
324 |         model_ema = ModelEma(
325 |             model,
326 |             decay=args.model_ema_decay,
327 |             device='cpu' if args.model_ema_force_cpu else '',
328 |             resume='')
329 |         print("Using EMA with decay = %.8f" % args.model_ema_decay)
330 | 
331 |     model_without_ddp = model
332 |     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
333 | 
334 |     print("Model = %s" % str(model_without_ddp))
335 |     print('number of params:', n_parameters)
336 | 
337 |     total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
338 |     num_training_steps_per_epoch = len(dataset_train) // total_batch_size
339 |     print("LR = %.8f" % args.lr)
340 |     print("Batch size = %d" % total_batch_size)
341 |     print("Update frequency = %d" % args.update_freq)
342 |     print("Number of training examples = %d" % len(dataset_train))
343 |     print("Number of training steps per epoch = %d" % num_training_steps_per_epoch)
344 | 
345 |     if args.layer_decay < 1.0 or args.layer_decay > 1.0:
346 |         num_layers = 12  # convnext layers divided into 12 parts, each with a different decayed lr value.
347 | assert args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge'], \ 348 | "Layer Decay impl only supports convnext_small/base/large/xlarge" 349 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 350 | else: 351 | assigner = None 352 | 353 | if assigner is not None: 354 | print("Assigned values = %s" % str(assigner.values)) 355 | 356 | if args.distributed: 357 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 358 | model_without_ddp = model.module 359 | 360 | optimizer = create_optimizer( 361 | args, model_without_ddp, skip_list=None, 362 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 363 | get_layer_scale=assigner.get_scale if assigner is not None else None) 364 | 365 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 366 | 367 | print("Use Cosine LR scheduler") 368 | lr_schedule_values = utils.cosine_scheduler( 369 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 370 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 371 | ) 372 | 373 | if args.weight_decay_end is None: 374 | args.weight_decay_end = args.weight_decay 375 | wd_schedule_values = utils.cosine_scheduler( 376 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 377 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 378 | 379 | schedules = {} 380 | 381 | # At most one of dropout and stochastic depth should be enabled. 382 | assert(args.dropout == 0 or args.drop_path == 0) 383 | # ConvNeXt does not support dropout. 384 | assert(args.dropout == 0 if args.model.startswith("convnext") else True) 385 | 386 | if args.dropout > 0: 387 | schedules['do'] = drop_scheduler( 388 | args.dropout, args.epochs, num_training_steps_per_epoch, 389 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 390 | print("Min DO = %.7f, Max DO = %.7f" % (min(schedules['do']), max(schedules['do']))) 391 | 392 | if args.drop_path > 0: 393 | schedules['dp'] = drop_scheduler( 394 | args.drop_path, args.epochs, num_training_steps_per_epoch, 395 | args.cutoff_epoch, args.drop_mode, args.drop_schedule) 396 | print("Min DP = %.7f, Max DP = %.7f" % (min(schedules['dp']), max(schedules['dp']))) 397 | 398 | if mixup_fn is not None: 399 | # smoothing is handled with mixup label transform 400 | criterion = SoftTargetCrossEntropy() 401 | elif args.smoothing > 0.: 402 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 403 | else: 404 | criterion = torch.nn.CrossEntropyLoss() 405 | 406 | print("criterion = %s" % str(criterion)) 407 | 408 | utils.auto_load_model( 409 | args=args, model=model, model_without_ddp=model_without_ddp, 410 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 411 | 412 | if args.eval: 413 | print(f"Eval only mode") 414 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 415 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 416 | return 417 | 418 | max_accuracy = 0.0 419 | if args.model_ema and args.model_ema_eval: 420 | max_accuracy_ema = 0.0 421 | 422 | print("Start training for %d epochs" % args.epochs) 423 | start_time = time.time() 424 | for epoch in range(args.start_epoch, args.epochs): 425 | if args.distributed: 426 | data_loader_train.sampler.set_epoch(epoch) 427 | if log_writer is not None: 428 | 
log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 429 | if wandb_logger: 430 | wandb_logger.set_steps() 431 | train_stats = train_one_epoch( 432 | model, criterion, data_loader_train, optimizer, 433 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 434 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 435 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, schedules=schedules, 436 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 437 | use_amp=args.use_amp 438 | ) 439 | if args.output_dir and args.save_ckpt: 440 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 441 | utils.save_model( 442 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 443 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 444 | if data_loader_val is not None: 445 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 446 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 447 | if max_accuracy < test_stats["acc1"]: 448 | max_accuracy = test_stats["acc1"] 449 | if args.output_dir and args.save_ckpt: 450 | utils.save_model( 451 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 452 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 453 | print(f'Max accuracy: {max_accuracy:.2f}%') 454 | 455 | if log_writer is not None: 456 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 457 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 458 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 459 | 460 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 461 | **{f'test_{k}': v for k, v in test_stats.items()}, 462 | 'epoch': epoch, 463 | 'n_parameters': n_parameters} 464 | 465 | # repeat testing routines for EMA, if ema eval is turned on 466 | if args.model_ema and args.model_ema_eval: 467 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 468 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 469 | if max_accuracy_ema < test_stats_ema["acc1"]: 470 | max_accuracy_ema = test_stats_ema["acc1"] 471 | if args.output_dir and args.save_ckpt: 472 | utils.save_model( 473 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 474 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 475 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 476 | if log_writer is not None: 477 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 478 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 479 | else: 480 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 481 | 'epoch': epoch, 482 | 'n_parameters': n_parameters} 483 | 484 | if args.output_dir and utils.is_main_process(): 485 | if log_writer is not None: 486 | log_writer.flush() 487 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 488 | f.write(json.dumps(log_stats) + "\n") 489 | 490 | if wandb_logger: 491 | wandb_logger.log_epoch_metrics(log_stats) 492 | 493 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 494 | wandb_logger.log_checkpoints() 495 | 496 | 497 | total_time = time.time() - start_time 498 | 
total_time_str = str(datetime.timedelta(seconds=int(total_time))) 499 | print('Training time {}'.format(total_time_str)) 500 | 501 | if __name__ == '__main__': 502 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()]) 503 | args = parser.parse_args() 504 | if args.output_dir: 505 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 506 | main(args) 507 | -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from timm.models.layers import trunc_normal_, DropPath 12 | from timm.models.registry import register_model 13 | 14 | class Block(nn.Module): 15 | r""" ConvNeXt Block. There are two equivalent implementations: 16 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 17 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 18 | We use (2) as we find it slightly faster in PyTorch 19 | 20 | Args: 21 | dim (int): Number of input channels. 22 | drop_path (float): Stochastic depth rate. Default: 0.0 23 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 24 | """ 25 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, drop_rate=0.): 26 | super().__init__() 27 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 28 | self.norm = LayerNorm(dim, eps=1e-6) 29 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 30 | self.act = nn.GELU() 31 | self.pwconv2 = nn.Linear(4 * dim, dim) 32 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 33 | requires_grad=True) if layer_scale_init_value > 0 else None 34 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 35 | self.dropout = nn.Dropout(drop_rate) 36 | 37 | def forward(self, x): 38 | input = x 39 | x = self.dwconv(x) 40 | x = self.dropout(x) 41 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 42 | x = self.norm(x) 43 | x = self.pwconv1(x) 44 | x = self.act(x) 45 | x = self.dropout(x) 46 | x = self.pwconv2(x) 47 | x = self.dropout(x) 48 | if self.gamma is not None: 49 | x = self.gamma * x 50 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 51 | x = input + self.drop_path(x) 52 | return x 53 | 54 | 55 | class ConvNeXt(nn.Module): 56 | r""" ConvNeXt 57 | A PyTorch impl of : `A ConvNet for the 2020s` - 58 | https://arxiv.org/pdf/2201.03545.pdf 59 | 60 | Args: 61 | in_chans (int): Number of input image channels. Default: 3 62 | num_classes (int): Number of classes for classification head. Default: 1000 63 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 64 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 65 | drop_path_rate (float): Stochastic depth rate. Default: 0. 66 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 67 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
68 | drop_rate (float): Dropout rate 69 | """ 70 | def __init__(self, in_chans=3, num_classes=1000, 71 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 72 | layer_scale_init_value=1e-6, head_init_scale=1., drop_rate=0. 73 | ): 74 | super().__init__() 75 | 76 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 77 | stem = nn.Sequential( 78 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 79 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 80 | ) 81 | self.downsample_layers.append(stem) 82 | for i in range(3): 83 | downsample_layer = nn.Sequential( 84 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 85 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 86 | ) 87 | self.downsample_layers.append(downsample_layer) 88 | 89 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 90 | self.depths = depths 91 | self.drop_path = drop_path_rate 92 | self.drop_rate = drop_rate 93 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 94 | cur = 0 95 | for i in range(4): 96 | stage = nn.Sequential( 97 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 98 | layer_scale_init_value=layer_scale_init_value, drop_rate=drop_rate) for j in range(depths[i])] 99 | ) 100 | self.stages.append(stage) 101 | cur += depths[i] 102 | 103 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 104 | self.head = nn.Linear(dims[-1], num_classes) 105 | 106 | self.apply(self._init_weights) 107 | self.head.weight.data.mul_(head_init_scale) 108 | self.head.bias.data.mul_(head_init_scale) 109 | 110 | def _init_weights(self, m): 111 | if isinstance(m, (nn.Conv2d, nn.Linear)): 112 | trunc_normal_(m.weight, std=.02) 113 | nn.init.constant_(m.bias, 0) 114 | 115 | def forward_features(self, x): 116 | for i in range(4): 117 | x = self.downsample_layers[i](x) 118 | x = self.stages[i](x) 119 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 120 | 121 | def forward(self, x): 122 | x = self.forward_features(x) 123 | x = self.head(x) 124 | return x 125 | 126 | def update_drop_path(self, drop_path_rate): 127 | self.drop_path = drop_path_rate 128 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 129 | cur = 0 130 | for i in range(4): 131 | for j in range(self.depths[i]): 132 | self.stages[i][j].drop_path.drop_prob = dp_rates[cur + j] 133 | cur += self.depths[i] 134 | 135 | 136 | class LayerNorm(nn.Module): 137 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 138 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 139 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 140 | with shape (batch_size, channels, height, width). 
141 | """ 142 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 143 | super().__init__() 144 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 145 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 146 | self.eps = eps 147 | self.data_format = data_format 148 | if self.data_format not in ["channels_last", "channels_first"]: 149 | raise NotImplementedError 150 | self.normalized_shape = (normalized_shape, ) 151 | 152 | def forward(self, x): 153 | if self.data_format == "channels_last": 154 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 155 | elif self.data_format == "channels_first": 156 | u = x.mean(1, keepdim=True) 157 | s = (x - u).pow(2).mean(1, keepdim=True) 158 | x = (x - u) / torch.sqrt(s + self.eps) 159 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 160 | return x 161 | 162 | 163 | model_urls = { 164 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 165 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 166 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 167 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 168 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 169 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 170 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 171 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 172 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 173 | } 174 | 175 | @register_model 176 | def convnext_atto(pretrained=False, **kwargs): 177 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 178 | return model 179 | 180 | @register_model 181 | def convnext_mini(pretrained=False, **kwargs): 182 | model = ConvNeXt(depths=[2, 2, 4, 2], dims=[48, 96, 192, 384], **kwargs) 183 | return model 184 | 185 | @register_model 186 | def convnext_femto(pretrained=False, **kwargs): 187 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 188 | return model 189 | 190 | @register_model 191 | def convnext_pico(pretrained=False, **kwargs): 192 | # timm pico variant 193 | model = ConvNeXt(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 194 | return model 195 | 196 | @register_model 197 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 198 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 199 | return model 200 | 201 | @register_model 202 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 203 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 204 | return model 205 | 206 | @register_model 207 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 208 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 209 | return model 210 | 211 | @register_model 212 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 213 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 214 | return model 215 | 216 | @register_model 217 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 218 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 219 | return model 220 | 
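# ---------------------------------------------------------------------------
# Editor's note: minimal usage sketch, not part of the original file. It shows
# how the ConvNeXt variants registered above take a stochastic-depth rate at
# construction and how `update_drop_path` re-assigns the per-block drop
# probabilities, which is the hook a drop-path schedule (see drop_scheduler.py
# and main.py) can use during training. Shapes and rate values below are
# illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    model = convnext_tiny(drop_path_rate=0.1)  # rates are linearly scaled across the 18 blocks
    x = torch.randn(2, 3, 224, 224)            # dummy (N, C, H, W) batch
    logits = model(x)                          # -> torch.Size([2, 1000])
    model.update_drop_path(0.3)                # e.g., a schedule raising stochastic depth mid-training
    print(logits.shape, model.drop_path)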
-------------------------------------------------------------------------------- /models/mlp_mixer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import math 9 | from copy import deepcopy 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply 17 | from timm.models.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple 18 | from timm.models.registry import register_model 19 | 20 | 21 | def _cfg(url='', **kwargs): 22 | return { 23 | 'url': url, 24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 25 | 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True, 26 | 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 27 | 'first_conv': 'stem.proj', 'classifier': 'head', 28 | **kwargs 29 | } 30 | 31 | 32 | default_cfgs = dict( 33 | mixer_s32_224=_cfg(), 34 | mixer_s16_224=_cfg(), 35 | mixer_b32_224=_cfg(), 36 | mixer_b16_224=_cfg( 37 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth', 38 | ), 39 | mixer_b16_224_in21k=_cfg( 40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth', 41 | num_classes=21843 42 | ), 43 | mixer_l32_224=_cfg(), 44 | mixer_l16_224=_cfg( 45 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth', 46 | ), 47 | mixer_l16_224_in21k=_cfg( 48 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', 49 | num_classes=21843 50 | ), 51 | 52 | # Mixer ImageNet-21K-P pretraining 53 | mixer_b16_224_miil_in21k=_cfg( 54 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth', 55 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 56 | ), 57 | mixer_b16_224_miil=_cfg( 58 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth', 59 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 60 | ), 61 | 62 | gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 63 | gmixer_24_224=_cfg( 64 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth', 65 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 66 | 67 | resmlp_12_224=_cfg( 68 | url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth', 69 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 70 | resmlp_24_224=_cfg( 71 | url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth', 72 | #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth', 73 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 74 | resmlp_36_224=_cfg( 75 | url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth', 76 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 77 | resmlp_big_24_224=_cfg( 78 | 
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth', 79 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 80 | 81 | resmlp_12_distilled_224=_cfg( 82 | url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth', 83 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 84 | resmlp_24_distilled_224=_cfg( 85 | url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth', 86 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 87 | resmlp_36_distilled_224=_cfg( 88 | url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth', 89 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 90 | resmlp_big_24_distilled_224=_cfg( 91 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth', 92 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 93 | 94 | resmlp_big_24_224_in22ft1k=_cfg( 95 | url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth', 96 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 97 | 98 | gmlp_ti16_224=_cfg(), 99 | gmlp_s16_224=_cfg( 100 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth', 101 | ), 102 | gmlp_b16_224=_cfg(), 103 | ) 104 | 105 | 106 | class MixerBlock(nn.Module): 107 | """ Residual Block w/ token mixing and channel MLPs 108 | Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 109 | """ 110 | def __init__( 111 | self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp, 112 | norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): 113 | super().__init__() 114 | tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] 115 | self.norm1 = norm_layer(dim) 116 | self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) 117 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 118 | self.norm2 = norm_layer(dim) 119 | self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) 120 | 121 | def forward(self, x): 122 | x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) 123 | x = x + self.drop_path(self.mlp_channels(self.norm2(x))) 124 | return x 125 | 126 | 127 | class MlpMixer(nn.Module): 128 | 129 | def __init__( 130 | self, 131 | num_classes=1000, 132 | img_size=224, 133 | in_chans=3, 134 | patch_size=16, 135 | num_blocks=8, 136 | embed_dim=512, 137 | mlp_ratio=(0.5, 4.0), 138 | block_layer=MixerBlock, 139 | mlp_layer=Mlp, 140 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 141 | act_layer=nn.GELU, 142 | drop_rate=0., 143 | drop_path_rate=0., 144 | nlhb=False, 145 | stem_norm=False, 146 | ): 147 | super().__init__() 148 | self.num_classes = num_classes 149 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 150 | 151 | self.stem = PatchEmbed( 152 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, 153 | embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None) 154 | # FIXME drop_path (stochastic depth scaling rule or all the same?) 
155 | self.drop_path = drop_path_rate 156 | self.drop_rate = drop_rate 157 | self.num_blocks = num_blocks 158 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)] 159 | self.blocks = nn.Sequential(*[ 160 | block_layer( 161 | embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, 162 | act_layer=act_layer, drop=drop_rate, drop_path=dpr[i]) 163 | for i in range(num_blocks)]) 164 | self.norm = norm_layer(embed_dim) 165 | self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 166 | 167 | self.init_weights(nlhb=nlhb) 168 | 169 | def init_weights(self, nlhb=False): 170 | head_bias = -math.log(self.num_classes) if nlhb else 0. 171 | named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first 172 | 173 | def get_classifier(self): 174 | return self.head 175 | 176 | def reset_classifier(self, num_classes, global_pool=''): 177 | self.num_classes = num_classes 178 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 179 | 180 | def forward_features(self, x): 181 | x = self.stem(x) 182 | x = self.blocks(x) 183 | x = self.norm(x) 184 | x = x.mean(dim=1) 185 | return x 186 | 187 | def forward(self, x): 188 | x = self.forward_features(x) 189 | x = self.head(x) 190 | return x 191 | 192 | def update_drop_path(self, drop_path_rate): 193 | self.drop_path = drop_path_rate 194 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_blocks)] 195 | cur = 0 196 | for block in self.blocks: 197 | block.drop_path.drop_prob = dpr[cur] 198 | cur += 1 199 | assert cur == self.num_blocks 200 | 201 | def update_dropout(self, drop_rate): 202 | self.drop_rate = drop_rate 203 | for module in self.modules(): 204 | if isinstance(module, nn.Dropout): 205 | module.p = drop_rate 206 | 207 | 208 | def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): 209 | """ Mixer weight initialization (trying to match Flax defaults) 210 | """ 211 | if isinstance(module, nn.Linear): 212 | if name.startswith('head'): 213 | nn.init.zeros_(module.weight) 214 | nn.init.constant_(module.bias, head_bias) 215 | else: 216 | if flax: 217 | # Flax defaults 218 | lecun_normal_(module.weight) 219 | if module.bias is not None: 220 | nn.init.zeros_(module.bias) 221 | else: 222 | # like MLP init in vit (my original init) 223 | nn.init.xavier_uniform_(module.weight) 224 | if module.bias is not None: 225 | if 'mlp' in name: 226 | nn.init.normal_(module.bias, std=1e-6) 227 | else: 228 | nn.init.zeros_(module.bias) 229 | elif isinstance(module, nn.Conv2d): 230 | lecun_normal_(module.weight) 231 | if module.bias is not None: 232 | nn.init.zeros_(module.bias) 233 | elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): 234 | nn.init.ones_(module.weight) 235 | nn.init.zeros_(module.bias) 236 | elif hasattr(module, 'init_weights'): 237 | # NOTE if a parent module contains init_weights method, it can override the init of the 238 | # child modules as this will be called in depth-first order. 
239 | module.init_weights() 240 | 241 | 242 | def checkpoint_filter_fn(state_dict, model): 243 | """ Remap checkpoints if needed """ 244 | if 'patch_embed.proj.weight' in state_dict: 245 | # Remap FB ResMlp models -> timm 246 | out_dict = {} 247 | for k, v in state_dict.items(): 248 | k = k.replace('patch_embed.', 'stem.') 249 | k = k.replace('attn.', 'linear_tokens.') 250 | k = k.replace('mlp.', 'mlp_channels.') 251 | k = k.replace('gamma_', 'ls') 252 | if k.endswith('.alpha') or k.endswith('.beta'): 253 | v = v.reshape(1, 1, -1) 254 | out_dict[k] = v 255 | return out_dict 256 | return state_dict 257 | 258 | 259 | def _create_mixer(variant, pretrained=False, **kwargs): 260 | if kwargs.get('features_only', None): 261 | raise RuntimeError('features_only not implemented for MLP-Mixer models.') 262 | 263 | model = build_model_with_cfg( 264 | MlpMixer, variant, pretrained, 265 | default_cfg=default_cfgs[variant], 266 | pretrained_filter_fn=checkpoint_filter_fn, 267 | **kwargs) 268 | return model 269 | 270 | @register_model 271 | def mixer_t32(pretrained=False, **kwargs): 272 | """ Mixer-S/32 224x224 273 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 274 | """ 275 | model = MlpMixer(patch_size=32, num_blocks=8, embed_dim=256, **kwargs) 276 | return model 277 | 278 | 279 | @register_model 280 | def mixer_s32(pretrained=False, **kwargs): 281 | """ Mixer-S/32 224x224 282 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 283 | """ 284 | model = MlpMixer(patch_size=32, num_blocks=8, embed_dim=512, **kwargs) 285 | return model 286 | 287 | 288 | @register_model 289 | def mixer_s16(pretrained=False, **kwargs): 290 | """ Mixer-S/16 224x224 291 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 292 | """ 293 | model = MlpMixer(patch_size=16, num_blocks=8, embed_dim=512, **kwargs) 294 | return model 295 | 296 | 297 | @register_model 298 | def mixer_b32(pretrained=False, **kwargs): 299 | """ Mixer-B/32 224x224 300 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 301 | """ 302 | model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs) 303 | model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) 304 | return model 305 | 306 | 307 | @register_model 308 | def mixer_b16(pretrained=False, **kwargs): 309 | """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. 310 | Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 311 | """ 312 | model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) 313 | model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) 314 | return model 315 | 316 | -------------------------------------------------------------------------------- /models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import logging 9 | import math 10 | from functools import partial 11 | from typing import Optional 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 17 | #from timm.models.fx_features import register_notrace_function 18 | from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_ 19 | from timm.models.registry import register_model 20 | 21 | 22 | _logger = logging.getLogger(__name__) 23 | 24 | 25 | def _cfg(url='', **kwargs): 26 | return { 27 | 'url': url, 28 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 29 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 30 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 31 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 32 | **kwargs 33 | } 34 | 35 | 36 | default_cfgs = { 37 | 'swin_base_patch4_window12_384': _cfg( 38 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', 39 | input_size=(3, 384, 384), crop_pct=1.0), 40 | 41 | 'swin_base_patch4_window7_224': _cfg( 42 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', 43 | ), 44 | 45 | 'swin_large_patch4_window12_384': _cfg( 46 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', 47 | input_size=(3, 384, 384), crop_pct=1.0), 48 | 49 | 'swin_large_patch4_window7_224': _cfg( 50 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', 51 | ), 52 | 53 | 'swin_small_patch4_window7_224': _cfg( 54 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', 55 | ), 56 | 57 | 'swin_tiny_patch4_window7_224': _cfg( 58 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', 59 | ), 60 | 61 | 'swin_base_patch4_window12_384_in22k': _cfg( 62 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', 63 | input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), 64 | 65 | 'swin_base_patch4_window7_224_in22k': _cfg( 66 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', 67 | num_classes=21841), 68 | 69 | 'swin_large_patch4_window12_384_in22k': _cfg( 70 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', 71 | input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), 72 | 73 | 'swin_large_patch4_window7_224_in22k': _cfg( 74 | url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', 75 | num_classes=21841), 76 | 77 | 'swin_s3_tiny_224': _cfg( 78 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth' 79 | ), 80 | 'swin_s3_small_224': _cfg( 81 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth' 82 | ), 83 | 'swin_s3_base_224': _cfg( 84 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth' 85 | ) 86 | } 87 | 88 | 89 | def window_partition(x, window_size: int): 90 | """ 91 | Args: 92 | x: (B, H, W, C) 93 | window_size (int): window size 94 | Returns: 95 | windows: (num_windows*B, 
window_size, window_size, C) 96 | """ 97 | B, H, W, C = x.shape 98 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 99 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 100 | return windows 101 | 102 | 103 | #@register_notrace_function # reason: int argument is a Proxy 104 | def window_reverse(windows, window_size: int, H: int, W: int): 105 | """ 106 | Args: 107 | windows: (num_windows*B, window_size, window_size, C) 108 | window_size (int): Window size 109 | H (int): Height of image 110 | W (int): Width of image 111 | Returns: 112 | x: (B, H, W, C) 113 | """ 114 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 115 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 116 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 117 | return x 118 | 119 | 120 | def get_relative_position_index(win_h, win_w): 121 | # get pair-wise relative position index for each token inside the window 122 | coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) # 2, Wh, Ww 123 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 124 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 125 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 126 | relative_coords[:, :, 0] += win_h - 1 # shift to start from 0 127 | relative_coords[:, :, 1] += win_w - 1 128 | relative_coords[:, :, 0] *= 2 * win_w - 1 129 | return relative_coords.sum(-1) # Wh*Ww, Wh*Ww 130 | 131 | 132 | class WindowAttention(nn.Module): 133 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 134 | It supports both of shifted and non-shifted window. 135 | Args: 136 | dim (int): Number of input channels. 137 | num_heads (int): Number of attention heads. 138 | head_dim (int): Number of channels per head (dim // num_heads if not set) 139 | window_size (tuple[int]): The height and width of the window. 140 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 141 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 142 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0
143 |     """
144 | 
145 |     def __init__(self, dim, num_heads, head_dim=None, window_size=7, qkv_bias=True, attn_drop=0., proj_drop=0.):
146 | 
147 |         super().__init__()
148 |         self.dim = dim
149 |         self.window_size = to_2tuple(window_size)  # Wh, Ww
150 |         win_h, win_w = self.window_size
151 |         self.window_area = win_h * win_w
152 |         self.num_heads = num_heads
153 |         head_dim = head_dim or dim // num_heads
154 |         attn_dim = head_dim * num_heads
155 |         self.scale = head_dim ** -0.5
156 | 
157 |         # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
158 |         self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
159 | 
160 |         # get pair-wise relative position index for each token inside the window
161 |         self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w))
162 | 
163 |         self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
164 |         self.attn_drop = nn.Dropout(attn_drop)
165 |         self.proj = nn.Linear(attn_dim, dim)
166 |         self.proj_drop = nn.Dropout(proj_drop)
167 | 
168 |         trunc_normal_(self.relative_position_bias_table, std=.02)
169 |         self.softmax = nn.Softmax(dim=-1)
170 | 
171 |     def _get_rel_pos_bias(self) -> torch.Tensor:
172 |         relative_position_bias = self.relative_position_bias_table[
173 |             self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1)  # Wh*Ww,Wh*Ww,nH
174 |         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
175 |         return relative_position_bias.unsqueeze(0)
176 | 
177 |     def forward(self, x, mask: Optional[torch.Tensor] = None):
178 |         """
179 |         Args:
180 |             x: input features with shape of (num_windows*B, N, C)
181 |             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
182 |         """
183 |         B_, N, C = x.shape
184 |         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
185 |         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
186 | 
187 |         q = q * self.scale
188 |         attn = (q @ k.transpose(-2, -1))
189 |         attn = attn + self._get_rel_pos_bias()
190 | 
191 |         if mask is not None:
192 |             num_win = mask.shape[0]
193 |             attn = attn.view(B_ // num_win, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
194 |             attn = attn.view(-1, self.num_heads, N, N)
195 |             attn = self.softmax(attn)
196 |         else:
197 |             attn = self.softmax(attn)
198 | 
199 |         attn = self.attn_drop(attn)
200 | 
201 |         x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
202 |         x = self.proj(x)
203 |         x = self.proj_drop(x)
204 |         return x
205 | 
206 | 
207 | class SwinTransformerBlock(nn.Module):
208 |     r""" Swin Transformer Block.
209 |     Args:
210 |         dim (int): Number of input channels.
211 |         input_resolution (tuple[int]): Input resolution.
212 |         window_size (int): Window size.
213 |         num_heads (int): Number of attention heads.
214 |         head_dim (int): Enforce the number of channels per head
215 |         shift_size (int): Shift size for SW-MSA.
216 |         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
217 |         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
218 |         drop (float, optional): Dropout rate. Default: 0.0
219 |         attn_drop (float, optional): Attention dropout rate. Default: 0.0
220 |         drop_path (float, optional): Stochastic depth rate. Default: 0.0
221 |         act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
222 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
223 |     """
224 | 
225 |     def __init__(
226 |             self, dim, input_resolution, num_heads=4, head_dim=None, window_size=7, shift_size=0,
227 |             mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
228 |             act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_rate=0.):
229 |         super().__init__()
230 |         self.dim = dim
231 |         self.input_resolution = input_resolution
232 |         self.window_size = window_size
233 |         self.shift_size = shift_size
234 |         self.mlp_ratio = mlp_ratio
235 |         drop = drop_rate
236 |         attn_drop = drop_rate
237 |         self.drop_rate = drop_rate
238 |         if min(self.input_resolution) <= self.window_size:
239 |             # if window size is larger than input resolution, we don't partition windows
240 |             self.shift_size = 0
241 |             self.window_size = min(self.input_resolution)
242 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
243 | 
244 |         self.norm1 = norm_layer(dim)
245 |         self.attn = WindowAttention(
246 |             dim, num_heads=num_heads, head_dim=head_dim, window_size=to_2tuple(self.window_size),
247 |             qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
248 | 
249 |         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
250 |         self.norm2 = norm_layer(dim)
251 |         self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
252 | 
253 |         if self.shift_size > 0:
254 |             # calculate attention mask for SW-MSA
255 |             H, W = self.input_resolution
256 |             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
257 |             cnt = 0
258 |             for h in (
259 |                     slice(0, -self.window_size),
260 |                     slice(-self.window_size, -self.shift_size),
261 |                     slice(-self.shift_size, None)):
262 |                 for w in (
263 |                         slice(0, -self.window_size),
264 |                         slice(-self.window_size, -self.shift_size),
265 |                         slice(-self.shift_size, None)):
266 |                     img_mask[:, h, w, :] = cnt
267 |                     cnt += 1
268 |             mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1
269 |             mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
270 |             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
271 |             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
272 |         else:
273 |             attn_mask = None
274 | 
275 |         self.register_buffer("attn_mask", attn_mask)
276 | 
277 |     def forward(self, x):
278 |         H, W = self.input_resolution
279 |         B, L, C = x.shape
280 |         assert L == H * W, "input feature has wrong size"
281 | 
282 |         shortcut = x
283 |         x = self.norm1(x)
284 |         x = x.view(B, H, W, C)
285 | 
286 |         # cyclic shift
287 |         if self.shift_size > 0:
288 |             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
289 |         else:
290 |             shifted_x = x
291 | 
292 |         # partition windows
293 |         x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C
294 |         x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_win*B, window_size*window_size, C
295 | 
296 |         # W-MSA/SW-MSA
297 |         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C
298 | 
299 |         # merge windows
300 |         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
301 |         shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
302 | 
303 |         # reverse cyclic shift
304 |         if self.shift_size > 0:
305 |             x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
306 |         else:
307 |             x = shifted_x
308 |         x = x.view(B, H * W, C)
309 | 
310 |         # FFN
311 |         x = shortcut + self.drop_path(x)
312 |         x = x + self.drop_path(self.mlp(self.norm2(x)))
313 | 
314 |         return x
315 | 
316 | 
317 | class PatchMerging(nn.Module):
318 |     r""" Patch Merging Layer.
319 |     Args:
320 |         input_resolution (tuple[int]): Resolution of input feature.
321 |         dim (int): Number of input channels.
322 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
323 |     """
324 | 
325 |     def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
326 |         super().__init__()
327 |         self.input_resolution = input_resolution
328 |         self.dim = dim
329 |         self.out_dim = out_dim or 2 * dim
330 |         self.norm = norm_layer(4 * dim)
331 |         self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
332 | 
333 |     def forward(self, x):
334 |         """
335 |         x: B, H*W, C
336 |         """
337 |         H, W = self.input_resolution
338 |         B, L, C = x.shape
339 |         assert L == H * W, "input feature has wrong size"
340 |         assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
341 | 
342 |         x = x.view(B, H, W, C)
343 | 
344 |         x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
345 |         x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
346 |         x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
347 |         x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
348 |         x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
349 |         x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
350 | 
351 |         x = self.norm(x)
352 |         x = self.reduction(x)
353 | 
354 |         return x
355 | 
356 | 
357 | class BasicLayer(nn.Module):
358 |     """ A basic Swin Transformer layer for one stage.
359 |     Args:
360 |         dim (int): Number of input channels.
361 |         input_resolution (tuple[int]): Input resolution.
362 |         depth (int): Number of blocks.
363 |         num_heads (int): Number of attention heads.
364 |         head_dim (int): Channels per head (dim // num_heads if not set)
365 |         window_size (int): Local window size.
366 |         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
367 |         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
368 |         drop (float, optional): Dropout rate. Default: 0.0
369 |         attn_drop (float, optional): Attention dropout rate. Default: 0.0
370 |         drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
371 |         norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
372 |         downsample (nn.Module | None, optional): Downsample layer at the end of the layer.
Default: None 373 | """ 374 | 375 | def __init__( 376 | self, dim, out_dim, input_resolution, depth, num_heads=4, head_dim=None, 377 | window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 378 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None): 379 | 380 | super().__init__() 381 | self.dim = dim 382 | self.input_resolution = input_resolution 383 | self.depth = depth 384 | self.grad_checkpointing = False 385 | 386 | # build blocks 387 | self.blocks = nn.Sequential(*[ 388 | SwinTransformerBlock( 389 | dim=dim, input_resolution=input_resolution, num_heads=num_heads, head_dim=head_dim, 390 | window_size=window_size, shift_size=0 if (i % 2 == 0) else window_size // 2, 391 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, 392 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) 393 | for i in range(depth)]) 394 | 395 | # patch merging layer 396 | if downsample is not None: 397 | self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, norm_layer=norm_layer) 398 | else: 399 | self.downsample = None 400 | 401 | def forward(self, x): 402 | x = self.blocks(x) 403 | if self.downsample is not None: 404 | x = self.downsample(x) 405 | return x 406 | 407 | 408 | class SwinTransformer(nn.Module): 409 | r""" Swin Transformer 410 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 411 | https://arxiv.org/pdf/2103.14030 412 | Args: 413 | img_size (int | tuple(int)): Input image size. Default 224 414 | patch_size (int | tuple(int)): Patch size. Default: 4 415 | in_chans (int): Number of input image channels. Default: 3 416 | num_classes (int): Number of classes for classification head. Default: 1000 417 | embed_dim (int): Patch embedding dimension. Default: 96 418 | depths (tuple(int)): Depth of each Swin Transformer layer. 419 | num_heads (tuple(int)): Number of attention heads in different layers. 420 | head_dim (int, tuple(int)): 421 | window_size (int): Window size. Default: 7 422 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 423 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 424 | drop_rate (float): Dropout rate. Default: 0 425 | attn_drop_rate (float): Attention dropout rate. Default: 0 426 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 427 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 428 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 429 | patch_norm (bool): If True, add normalization after patch embedding. 
Default: True 430 | """ 431 | 432 | def __init__( 433 | self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', 434 | embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), head_dim=None, 435 | window_size=7, mlp_ratio=4., qkv_bias=True, 436 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 437 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, weight_init='', **kwargs): 438 | super().__init__() 439 | assert global_pool in ('', 'avg') 440 | self.num_classes = num_classes 441 | self.global_pool = global_pool 442 | self.num_layers = len(depths) 443 | self.embed_dim = embed_dim 444 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 445 | 446 | # split image into non-overlapping patches 447 | self.patch_embed = PatchEmbed( 448 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 449 | norm_layer=norm_layer if patch_norm else None) 450 | num_patches = self.patch_embed.num_patches 451 | self.patch_grid = self.patch_embed.grid_size 452 | 453 | # absolute position embedding 454 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None 455 | self.pos_drop = nn.Dropout(p=drop_rate) 456 | self.drop_rate = drop_rate 457 | attn_drop_rate = drop_rate 458 | 459 | # build layers 460 | if not isinstance(embed_dim, (tuple, list)): 461 | embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] 462 | embed_out_dim = embed_dim[1:] + [None] 463 | head_dim = to_ntuple(self.num_layers)(head_dim) 464 | window_size = to_ntuple(self.num_layers)(window_size) 465 | mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) 466 | self.depth = sum(depths) 467 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 468 | self.drop_path = drop_path_rate 469 | layers = [] 470 | for i in range(self.num_layers): 471 | layers += [BasicLayer( 472 | dim=embed_dim[i], 473 | out_dim=embed_out_dim[i], 474 | input_resolution=(self.patch_grid[0] // (2 ** i), self.patch_grid[1] // (2 ** i)), 475 | depth=depths[i], 476 | num_heads=num_heads[i], 477 | head_dim=head_dim[i], 478 | window_size=window_size[i], 479 | mlp_ratio=mlp_ratio[i], 480 | qkv_bias=qkv_bias, 481 | drop=drop_rate, 482 | attn_drop=attn_drop_rate, 483 | drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], 484 | norm_layer=norm_layer, 485 | downsample=PatchMerging if (i < self.num_layers - 1) else None 486 | )] 487 | self.layers = nn.Sequential(*layers) 488 | 489 | self.norm = norm_layer(self.num_features) 490 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 491 | 492 | if weight_init != 'skip': 493 | self.apply(self._init_weights) 494 | 495 | 496 | def update_drop_path(self, drop_path_rate): 497 | self.drop_path = drop_path_rate 498 | dp_rates = [x.item() for x in 499 | torch.linspace(0, drop_path_rate, self.depth)] 500 | cur = 0 501 | for i in range(self.num_layers): 502 | for block in self.layers[i].blocks: 503 | block.drop_path.drop_prob = dp_rates[cur] 504 | cur += 1 505 | assert cur == self.depth 506 | 507 | def update_dropout(self, drop_rate): 508 | self.drop_rate = drop_rate 509 | for module in self.modules(): 510 | if isinstance(module, nn.Dropout): 511 | module.p = drop_rate 512 | 513 | 514 | @torch.jit.ignore 515 | def _init_weights(self, m): 516 | if isinstance(m, nn.Linear): 517 | trunc_normal_(m.weight, std=.02) 518 | if isinstance(m, nn.Linear) and m.bias is not None: 519 | nn.init.constant_(m.bias, 0) 520 | elif isinstance(m, 
nn.LayerNorm):
521 |             nn.init.constant_(m.bias, 0)
522 |             nn.init.constant_(m.weight, 1.0)
523 | 
524 |     @torch.jit.ignore
525 |     def no_weight_decay(self):
526 |         nwd = {'absolute_pos_embed'}
527 |         for n, _ in self.named_parameters():
528 |             if 'relative_position_bias_table' in n:
529 |                 nwd.add(n)
530 |         return nwd
531 | 
532 |     @torch.jit.ignore
533 |     def group_matcher(self, coarse=False):
534 |         return dict(
535 |             stem=r'^absolute_pos_embed|patch_embed',  # stem and embed
536 |             blocks=r'^layers\.(\d+)' if coarse else [
537 |                 (r'^layers\.(\d+).downsample', (0,)),
538 |                 (r'^layers\.(\d+)\.\w+\.(\d+)', None),
539 |                 (r'^norm', (99999,)),
540 |             ]
541 |         )
542 | 
543 |     @torch.jit.ignore
544 |     def set_grad_checkpointing(self, enable=True):
545 |         for l in self.layers:
546 |             l.grad_checkpointing = enable
547 | 
548 |     @torch.jit.ignore
549 |     def get_classifier(self):
550 |         return self.head
551 | 
552 |     def reset_classifier(self, num_classes, global_pool=None):
553 |         self.num_classes = num_classes
554 |         if global_pool is not None:
555 |             assert global_pool in ('', 'avg')
556 |             self.global_pool = global_pool
557 |         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
558 | 
559 |     def forward_features(self, x):
560 |         x = self.patch_embed(x)
561 |         if self.absolute_pos_embed is not None:
562 |             x = x + self.absolute_pos_embed
563 |         x = self.pos_drop(x)
564 |         x = self.layers(x)
565 |         x = self.norm(x)  # B L C
566 |         return x
567 | 
568 |     def forward_head(self, x, pre_logits: bool = False):
569 |         if self.global_pool == 'avg':
570 |             x = x.mean(dim=1)
571 |         return x if pre_logits else self.head(x)
572 | 
573 |     def forward(self, x):
574 |         x = self.forward_features(x)
575 |         x = self.forward_head(x)
576 |         return x
577 | 
578 | @register_model
579 | def swin_femto(pretrained=False, **kwargs):
580 |     """ Swin-Femto @ 224x224, window 7x7 (embed_dim=40, depths (2, 2, 6, 2))
581 |     """
582 |     model = SwinTransformer(
583 |         patch_size=4, window_size=7, embed_dim=40, depths=(2, 2, 6, 2), num_heads=(2, 4, 8, 16), **kwargs)
584 |     return model
585 | 
586 | @register_model
587 | def swin_tiny(pretrained=False, **kwargs):
588 |     """ Swin-T @ 224x224, window 7x7
589 |     """
590 |     model = SwinTransformer(
591 |         patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
592 |     return model
593 | 
594 | @register_model
595 | def swin_small(pretrained=False, **kwargs):
596 |     """ Swin-S @ 224x224, window 7x7
597 |     """
598 |     model = SwinTransformer(
599 |         patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
600 |     return model
601 | 
-------------------------------------------------------------------------------- /models/vision_transformer.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | # All rights reserved.
4 | 
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 8 | import torch 9 | import torch.nn as nn 10 | from functools import partial 11 | 12 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.models.helpers import load_pretrained 14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 15 | from timm.models.resnet import resnet26d, resnet50d 16 | from timm.models.registry import register_model 17 | 18 | 19 | def _cfg(url='', **kwargs): 20 | return { 21 | 'url': url, 22 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 23 | 'crop_pct': .9, 'interpolation': 'bicubic', 24 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 25 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 26 | **kwargs 27 | } 28 | 29 | 30 | default_cfgs = { 31 | # patch models 32 | 'vit_small_patch16_224': _cfg( 33 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 34 | ), 35 | 'vit_base_patch16_224': _cfg( 36 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 37 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 38 | ), 39 | 'vit_base_patch16_384': _cfg( 40 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', 41 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 42 | 'vit_base_patch32_384': _cfg( 43 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', 44 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 45 | 'vit_large_patch16_224': _cfg( 46 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 47 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 48 | 'vit_large_patch16_384': _cfg( 49 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', 50 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 51 | 'vit_large_patch32_384': _cfg( 52 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 53 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 54 | 'vit_huge_patch16_224': _cfg(), 55 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), 56 | # hybrid models 57 | 'vit_small_resnet26d_224': _cfg(), 58 | 'vit_small_resnet50d_s3_224': _cfg(), 59 | 'vit_base_resnet26d_224': _cfg(), 60 | 'vit_base_resnet50d_224': _cfg(), 61 | } 62 | 63 | 64 | class Mlp(nn.Module): 65 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 66 | super().__init__() 67 | out_features = out_features or in_features 68 | hidden_features = hidden_features or in_features 69 | self.fc1 = nn.Linear(in_features, hidden_features) 70 | self.act = act_layer() 71 | self.fc2 = nn.Linear(hidden_features, out_features) 72 | self.drop = nn.Dropout(drop) 73 | 74 | def forward(self, x): 75 | x = self.fc1(x) 76 | x = self.act(x) 77 | x = self.drop(x) 78 | x = self.fc2(x) 79 | x = self.drop(x) 80 | return x 81 | 82 | 83 | class Attention(nn.Module): 84 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 85 | super().__init__() 86 | self.num_heads = num_heads 87 | head_dim = dim // num_heads 88 | # NOTE scale factor was wrong in my original version, 
can set manually to be compat with prev weights 89 | self.scale = qk_scale or head_dim ** -0.5 90 | 91 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 92 | self.attn_drop = nn.Dropout(attn_drop) 93 | self.proj = nn.Linear(dim, dim) 94 | self.proj_drop = nn.Dropout(proj_drop) 95 | 96 | def forward(self, x): 97 | B, N, C = x.shape 98 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 100 | 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | 105 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 106 | x = self.proj(x) 107 | x = self.proj_drop(x) 108 | return x 109 | 110 | 111 | class Block(nn.Module): 112 | 113 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 114 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 115 | super().__init__() 116 | self.norm1 = norm_layer(dim) 117 | self.attn = Attention( 118 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 119 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 120 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 121 | self.norm2 = norm_layer(dim) 122 | mlp_hidden_dim = int(dim * mlp_ratio) 123 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 124 | 125 | def forward(self, x): 126 | x = x + self.drop_path(self.attn(self.norm1(x))) 127 | x = x + self.drop_path(self.mlp(self.norm2(x))) 128 | return x 129 | 130 | 131 | class PatchEmbed(nn.Module): 132 | """ Image to Patch Embedding 133 | """ 134 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 135 | super().__init__() 136 | img_size = to_2tuple(img_size) 137 | patch_size = to_2tuple(patch_size) 138 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 139 | self.img_size = img_size 140 | self.patch_size = patch_size 141 | self.num_patches = num_patches 142 | 143 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 144 | 145 | def forward(self, x): 146 | B, C, H, W = x.shape 147 | # FIXME look at relaxing size constraints 148 | assert H == self.img_size[0] and W == self.img_size[1], \ 149 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 150 | x = self.proj(x).flatten(2).transpose(1, 2) 151 | return x 152 | 153 | 154 | class HybridEmbed(nn.Module): 155 | """ CNN Feature Map Embedding 156 | Extract feature map from CNN, flatten, project to embedding dim. 157 | """ 158 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | assert isinstance(backbone, nn.Module) 161 | img_size = to_2tuple(img_size) 162 | self.img_size = img_size 163 | self.backbone = backbone 164 | if feature_size is None: 165 | with torch.no_grad(): 166 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 167 | # map for all networks, the feature metadata has reliable channel and stride info, but using 168 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
169 | training = backbone.training 170 | if training: 171 | backbone.eval() 172 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 173 | feature_size = o.shape[-2:] 174 | feature_dim = o.shape[1] 175 | backbone.train(training) 176 | else: 177 | feature_size = to_2tuple(feature_size) 178 | feature_dim = self.backbone.feature_info.channels()[-1] 179 | self.num_patches = feature_size[0] * feature_size[1] 180 | self.proj = nn.Linear(feature_dim, embed_dim) 181 | 182 | def forward(self, x): 183 | x = self.backbone(x)[-1] 184 | x = x.flatten(2).transpose(1, 2) 185 | x = self.proj(x) 186 | return x 187 | 188 | 189 | class VisionTransformer(nn.Module): 190 | """ Vision Transformer with support for patch or hybrid CNN input stage 191 | """ 192 | 193 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 194 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 195 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): 196 | super().__init__() 197 | self.num_classes = num_classes 198 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 199 | # I add these two lines 200 | self.drop_rate=drop_rate 201 | attn_drop_rate=drop_rate 202 | if hybrid_backbone is not None: 203 | self.patch_embed = HybridEmbed( 204 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 205 | else: 206 | self.patch_embed = PatchEmbed( 207 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 208 | num_patches = self.patch_embed.num_patches 209 | 210 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 211 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 212 | self.pos_drop = nn.Dropout(p=drop_rate) 213 | self.depth = depth 214 | 215 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 216 | self.blocks = nn.ModuleList([ 217 | Block( 218 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 219 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 220 | for i in range(depth)]) 221 | self.norm = norm_layer(embed_dim) 222 | 223 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 224 | #self.repr = nn.Linear(embed_dim, representation_size) 225 | #self.repr_act = nn.Tanh() 226 | 227 | # Classifier head 228 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 229 | 230 | trunc_normal_(self.pos_embed, std=.02) 231 | trunc_normal_(self.cls_token, std=.02) 232 | self.apply(self._init_weights) 233 | 234 | def _init_weights(self, m): 235 | if isinstance(m, nn.Linear): 236 | trunc_normal_(m.weight, std=.02) 237 | if isinstance(m, nn.Linear) and m.bias is not None: 238 | nn.init.constant_(m.bias, 0) 239 | elif isinstance(m, nn.LayerNorm): 240 | nn.init.constant_(m.bias, 0) 241 | nn.init.constant_(m.weight, 1.0) 242 | 243 | @torch.jit.ignore 244 | def no_weight_decay(self): 245 | return {'pos_embed', 'cls_token'} 246 | 247 | def get_classifier(self): 248 | return self.head 249 | 250 | def reset_classifier(self, num_classes, global_pool=''): 251 | self.num_classes = num_classes 252 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 253 | 254 | def forward_features(self, x): 255 | B = x.shape[0] 256 | x = self.patch_embed(x) 257 | 258 | cls_tokens = 
self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 259 | x = torch.cat((cls_tokens, x), dim=1) 260 | x = x + self.pos_embed 261 | x = self.pos_drop(x) 262 | 263 | for blk in self.blocks: 264 | x = blk(x) 265 | 266 | x = self.norm(x) 267 | return x[:, 0] 268 | 269 | def forward(self, x): 270 | x = self.forward_features(x) 271 | x = self.head(x) 272 | return x 273 | 274 | def update_drop_path(self, drop_path_rate): 275 | self.drop_path = drop_path_rate 276 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, self.depth)] 277 | for i in range(self.depth): 278 | self.blocks[i].drop_path.drop_prob = dp_rates[i] 279 | 280 | def update_dropout(self, drop_rate): 281 | self.drop_rate = drop_rate 282 | for module in self.modules(): 283 | if isinstance(module, nn.Dropout): 284 | module.p = drop_rate 285 | 286 | 287 | def _conv_filter(state_dict, patch_size=16): 288 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 289 | out_dict = {} 290 | for k, v in state_dict.items(): 291 | if 'patch_embed.proj.weight' in k: 292 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 293 | out_dict[k] = v 294 | return out_dict 295 | 296 | @register_model 297 | def vit_tiny(pretrained=False, **kwargs): 298 | model = VisionTransformer( 299 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 300 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 301 | return model 302 | 303 | @register_model 304 | def vit_small(pretrained=False, **kwargs): 305 | model = VisionTransformer( 306 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 307 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 308 | return model 309 | 310 | @register_model 311 | def vit_base(pretrained=False, **kwargs): 312 | model = VisionTransformer( 313 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 314 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 315 | return model 316 | 317 | @register_model 318 | def vit_large(pretrained=False, **kwargs): 319 | model = VisionTransformer( 320 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 321 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 322 | return model 323 | 324 | 325 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | from torch import optim as optim 10 | 11 | from timm.optim.adafactor import Adafactor 12 | from timm.optim.adahessian import Adahessian 13 | from timm.optim.adamp import AdamP 14 | from timm.optim.lookahead import Lookahead 15 | from timm.optim.nadam import Nadam 16 | from timm.optim.novograd import NovoGrad 17 | from timm.optim.nvnovograd import NvNovoGrad 18 | from timm.optim.radam import RAdam 19 | from timm.optim.rmsprop_tf import RMSpropTF 20 | from timm.optim.sgdp import SGDP 21 | 22 | import json 23 | 24 | try: 25 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 26 | has_apex = True 27 | except ImportError: 28 | has_apex = False 29 | 30 | 31 | def get_num_layer_for_convnext(var_name): 32 | """ 33 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 34 | consecutive blocks, including possible neighboring downsample layers; 35 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 36 | """ 37 | num_max_layer = 12 38 | if var_name.startswith("downsample_layers"): 39 | stage_id = int(var_name.split('.')[1]) 40 | if stage_id == 0: 41 | layer_id = 0 42 | elif stage_id == 1 or stage_id == 2: 43 | layer_id = stage_id + 1 44 | elif stage_id == 3: 45 | layer_id = 12 46 | return layer_id 47 | 48 | elif var_name.startswith("stages"): 49 | stage_id = int(var_name.split('.')[1]) 50 | block_id = int(var_name.split('.')[2]) 51 | if stage_id == 0 or stage_id == 1: 52 | layer_id = stage_id + 1 53 | elif stage_id == 2: 54 | layer_id = 3 + block_id // 3 55 | elif stage_id == 3: 56 | layer_id = 12 57 | return layer_id 58 | else: 59 | return num_max_layer + 1 60 | 61 | class LayerDecayValueAssigner(object): 62 | def __init__(self, values): 63 | self.values = values 64 | 65 | def get_scale(self, layer_id): 66 | return self.values[layer_id] 67 | 68 | def get_layer_id(self, var_name): 69 | return get_num_layer_for_convnext(var_name) 70 | 71 | 72 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 73 | parameter_group_names = {} 74 | parameter_group_vars = {} 75 | 76 | for name, param in model.named_parameters(): 77 | if not param.requires_grad: 78 | continue # frozen weights 79 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 80 | group_name = "no_decay" 81 | this_weight_decay = 0. 82 | else: 83 | group_name = "decay" 84 | this_weight_decay = weight_decay 85 | if get_num_layer is not None: 86 | layer_id = get_num_layer(name) 87 | group_name = "layer_%d_%s" % (layer_id, group_name) 88 | else: 89 | layer_id = None 90 | 91 | if group_name not in parameter_group_names: 92 | if get_layer_scale is not None: 93 | scale = get_layer_scale(layer_id) 94 | else: 95 | scale = 1. 
96 | 97 | parameter_group_names[group_name] = { 98 | "weight_decay": this_weight_decay, 99 | "params": [], 100 | "lr_scale": scale 101 | } 102 | parameter_group_vars[group_name] = { 103 | "weight_decay": this_weight_decay, 104 | "params": [], 105 | "lr_scale": scale 106 | } 107 | 108 | parameter_group_vars[group_name]["params"].append(param) 109 | parameter_group_names[group_name]["params"].append(name) 110 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 111 | return list(parameter_group_vars.values()) 112 | 113 | 114 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 115 | opt_lower = args.opt.lower() 116 | weight_decay = args.weight_decay 117 | # if weight_decay and filter_bias_and_bn: 118 | if filter_bias_and_bn: 119 | skip = {} 120 | if skip_list is not None: 121 | skip = skip_list 122 | elif hasattr(model, 'no_weight_decay'): 123 | skip = model.no_weight_decay() 124 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 125 | weight_decay = 0. 126 | else: 127 | parameters = model.parameters() 128 | 129 | if 'fused' in opt_lower: 130 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 131 | 132 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 133 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 134 | opt_args['eps'] = args.opt_eps 135 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 136 | opt_args['betas'] = args.opt_betas 137 | 138 | opt_split = opt_lower.split('_') 139 | opt_lower = opt_split[-1] 140 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 141 | opt_args.pop('eps', None) 142 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 143 | elif opt_lower == 'momentum': 144 | opt_args.pop('eps', None) 145 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 146 | elif opt_lower == 'adam': 147 | optimizer = optim.Adam(parameters, **opt_args) 148 | elif opt_lower == 'adamw': 149 | optimizer = optim.AdamW(parameters, **opt_args) 150 | elif opt_lower == 'nadam': 151 | optimizer = Nadam(parameters, **opt_args) 152 | elif opt_lower == 'radam': 153 | optimizer = RAdam(parameters, **opt_args) 154 | elif opt_lower == 'adamp': 155 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 156 | elif opt_lower == 'sgdp': 157 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 158 | elif opt_lower == 'adadelta': 159 | optimizer = optim.Adadelta(parameters, **opt_args) 160 | elif opt_lower == 'adafactor': 161 | if not args.lr: 162 | opt_args['lr'] = None 163 | optimizer = Adafactor(parameters, **opt_args) 164 | elif opt_lower == 'adahessian': 165 | optimizer = Adahessian(parameters, **opt_args) 166 | elif opt_lower == 'rmsprop': 167 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 168 | elif opt_lower == 'rmsproptf': 169 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 170 | elif opt_lower == 'novograd': 171 | optimizer = NovoGrad(parameters, **opt_args) 172 | elif opt_lower == 'nvnovograd': 173 | optimizer = NvNovoGrad(parameters, **opt_args) 174 | elif opt_lower == 'fusedsgd': 175 | opt_args.pop('eps', None) 176 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 177 | elif opt_lower == 'fusedmomentum': 178 | opt_args.pop('eps', None) 179 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
180 |     elif opt_lower == 'fusedadam':
181 |         optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
182 |     elif opt_lower == 'fusedadamw':
183 |         optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
184 |     elif opt_lower == 'fusedlamb':
185 |         optimizer = FusedLAMB(parameters, **opt_args)
186 |     elif opt_lower == 'fusednovograd':
187 |         opt_args.setdefault('betas', (0.95, 0.98))
188 |         optimizer = FusedNovoGrad(parameters, **opt_args)
189 |     else:
190 |         assert False, "Invalid optimizer"
191 | 
192 |     if len(opt_split) > 1:
193 |         if opt_split[0] == 'lookahead':
194 |             optimizer = Lookahead(optimizer)
195 | 
196 |     return optimizer
197 | 
--------------------------------------------------------------------------------
/run_with_submitit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | # All rights reserved.
4 | 
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | 
9 | import argparse
10 | import os
11 | import uuid
12 | from pathlib import Path
13 | 
14 | import main as classification
15 | import submitit
16 | 
17 | 
18 | def parse_args():
19 |     classification_parser = classification.get_args_parser()
20 |     parser = argparse.ArgumentParser("Submitit", parents=[classification_parser])
21 |     parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
22 |     parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
23 |     parser.add_argument("--timeout", default=72, type=int, help="Duration of the job, in hours")
24 |     parser.add_argument("--job_name", default="dropout", type=str, help="Job name")
25 |     parser.add_argument("--job_dir", default="", type=str, help="Job directory; leave empty for default")
26 |     parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit")
27 |     parser.add_argument("--use_volta32", action='store_true', default=True, help="Big models? Use this")
28 |     parser.add_argument('--comment', default="", type=str,
29 |                         help='Comment to pass to scheduler, e.g. priority message')
30 |     return parser.parse_args()
31 | 
32 | def get_shared_folder() -> Path:
33 |     user = os.getenv("USER")
34 |     if Path("/checkpoint/").is_dir():
35 |         p = Path(f"/checkpoint/{user}/dropout")
36 |         p.mkdir(exist_ok=True)
37 |         return p
38 |     raise RuntimeError("No shared folder available")
39 | 
40 | def get_init_file():
41 |     # Init file must not exist, but its parent dir must exist.
42 | os.makedirs(str(get_shared_folder()), exist_ok=True) 43 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 44 | if init_file.exists(): 45 | os.remove(str(init_file)) 46 | return init_file 47 | 48 | class Trainer(object): 49 | def __init__(self, args): 50 | self.args = args 51 | 52 | def __call__(self): 53 | import main as classification 54 | 55 | self._setup_gpu_args() 56 | classification.main(self.args) 57 | 58 | def checkpoint(self): 59 | import os 60 | import submitit 61 | 62 | self.args.dist_url = get_init_file().as_uri() 63 | self.args.auto_resume = True 64 | print("Requeuing ", self.args) 65 | empty_trainer = type(self)(self.args) 66 | return submitit.helpers.DelayedSubmission(empty_trainer) 67 | 68 | def _setup_gpu_args(self): 69 | import submitit 70 | from pathlib import Path 71 | 72 | job_env = submitit.JobEnvironment() 73 | self.args.output_dir = Path(self.args.job_dir) 74 | self.args.gpu = job_env.local_rank 75 | self.args.rank = job_env.global_rank 76 | self.args.world_size = job_env.num_tasks 77 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 78 | 79 | 80 | def main(): 81 | args = parse_args() 82 | 83 | if args.job_dir == "": 84 | args.job_dir = get_shared_folder() / "%j" 85 | 86 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 87 | 88 | num_gpus_per_node = args.ngpus 89 | nodes = args.nodes 90 | timeout_min = args.timeout * 60 91 | 92 | partition = args.partition 93 | kwargs = {} 94 | if args.use_volta32: 95 | kwargs['slurm_constraint'] = 'volta32gb' 96 | if args.comment: 97 | kwargs['slurm_comment'] = args.comment 98 | 99 | executor.update_parameters( 100 | mem_gb=40 * num_gpus_per_node, 101 | gpus_per_node=num_gpus_per_node, 102 | tasks_per_node=num_gpus_per_node, # one task per GPU 103 | cpus_per_task=10, 104 | nodes=nodes, 105 | timeout_min=timeout_min, # max is 60 * 72 106 | # Below are cluster dependent parameters 107 | slurm_partition=partition, 108 | slurm_signal_delay_s=120, 109 | **kwargs 110 | ) 111 | 112 | executor.update_parameters(name=args.job_name) 113 | 114 | args.dist_url = get_init_file().as_uri() 115 | args.output_dir = args.job_dir 116 | 117 | trainer = Trainer(args) 118 | job = executor.submit(trainer) 119 | 120 | print("Submitted job_id:", job.job_id) 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | import math 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | import numpy as np 14 | from timm.utils import get_state_dict 15 | 16 | from pathlib import Path 17 | from timm.models import create_model 18 | import torch 19 | import torch.distributed as dist 20 | from torch._six import inf 21 | 22 | from tensorboardX import SummaryWriter 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 
27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | 
time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | class TensorboardLogger(object): 171 | def __init__(self, log_dir): 172 | self.writer = SummaryWriter(logdir=log_dir) 173 | self.step = 0 174 | 175 | def set_step(self, step=None): 176 | if step is not None: 177 | self.step = step 178 | else: 179 | self.step += 1 180 | 181 | def update(self, head='scalar', step=None, **kwargs): 182 | for k, v in kwargs.items(): 183 | if v is None: 184 | continue 185 | if isinstance(v, torch.Tensor): 186 | v = v.item() 187 | assert isinstance(v, (float, int)) 188 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 189 | 190 | def flush(self): 191 | self.writer.flush() 192 | 193 | 194 | class WandbLogger(object): 195 | def __init__(self, args): 196 | self.args = args 197 | 198 | try: 199 | import wandb 200 | self._wandb = wandb 201 | except ImportError: 202 | raise ImportError( 203 | "To use the Weights and Biases Logger please install wandb." 204 | "Run `pip install wandb` to install it." 205 | ) 206 | 207 | # Initialize a W&B run 208 | if self._wandb.run is None: 209 | self._wandb.init( 210 | project=args.project, 211 | config=args 212 | ) 213 | 214 | def log_epoch_metrics(self, metrics, commit=True): 215 | """ 216 | Log train/test metrics onto W&B. 217 | """ 218 | # Log number of model parameters as W&B summary 219 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 220 | metrics.pop('n_parameters', None) 221 | 222 | # Log current epoch 223 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 224 | metrics.pop('epoch') 225 | 226 | for k, v in metrics.items(): 227 | if 'train' in k: 228 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 229 | elif 'test' in k: 230 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 231 | 232 | self._wandb.log({}) 233 | 234 | def log_checkpoints(self): 235 | output_dir = self.args.output_dir 236 | model_artifact = self._wandb.Artifact( 237 | self._wandb.run.id + "_model", type="model" 238 | ) 239 | 240 | model_artifact.add_dir(output_dir) 241 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 242 | 243 | def set_steps(self): 244 | # Set global training step 245 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 246 | # Set epoch-wise step 247 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 248 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 249 | 250 | 251 | def setup_for_distributed(is_master): 252 | """ 253 | This function disables printing when not in master process 254 | """ 255 | import builtins as __builtin__ 256 | builtin_print = __builtin__.print 257 | 258 | def print(*args, **kwargs): 259 | force = kwargs.pop('force', False) 260 | if is_master or force: 261 | builtin_print(*args, **kwargs) 262 | 263 | __builtin__.print = print 264 | 265 | 266 | def is_dist_avail_and_initialized(): 267 | if not dist.is_available(): 268 | return False 269 | if not dist.is_initialized(): 270 | return False 271 | return True 272 | 
273 | 274 | def get_world_size(): 275 | if not is_dist_avail_and_initialized(): 276 | return 1 277 | return dist.get_world_size() 278 | 279 | 280 | def get_rank(): 281 | if not is_dist_avail_and_initialized(): 282 | return 0 283 | return dist.get_rank() 284 | 285 | 286 | def is_main_process(): 287 | return get_rank() == 0 288 | 289 | 290 | def save_on_master(*args, **kwargs): 291 | if is_main_process(): 292 | torch.save(*args, **kwargs) 293 | 294 | 295 | def init_distributed_mode(args): 296 | 297 | if args.dist_on_itp: 298 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 299 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 300 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 301 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 302 | os.environ['LOCAL_RANK'] = str(args.gpu) 303 | os.environ['RANK'] = str(args.rank) 304 | os.environ['WORLD_SIZE'] = str(args.world_size) 305 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 306 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 307 | args.rank = int(os.environ["RANK"]) 308 | args.world_size = int(os.environ['WORLD_SIZE']) 309 | args.gpu = int(os.environ['LOCAL_RANK']) 310 | elif 'SLURM_PROCID' in os.environ: 311 | args.rank = int(os.environ['SLURM_PROCID']) 312 | args.gpu = args.rank % torch.cuda.device_count() 313 | 314 | os.environ['RANK'] = str(args.rank) 315 | os.environ['LOCAL_RANK'] = str(args.gpu) 316 | os.environ['WORLD_SIZE'] = str(args.world_size) 317 | else: 318 | print('Not using distributed mode') 319 | args.distributed = False 320 | return 321 | 322 | args.distributed = True 323 | 324 | torch.cuda.set_device(args.gpu) 325 | args.dist_backend = 'nccl' 326 | print('| distributed init (rank {}): {}, gpu {}'.format( 327 | args.rank, args.dist_url, args.gpu), flush=True) 328 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 329 | world_size=args.world_size, rank=args.rank) 330 | torch.distributed.barrier() 331 | setup_for_distributed(args.rank == 0) 332 | 333 | 334 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 335 | missing_keys = [] 336 | unexpected_keys = [] 337 | error_msgs = [] 338 | # copy state_dict so _load_from_state_dict can modify it 339 | metadata = getattr(state_dict, '_metadata', None) 340 | state_dict = state_dict.copy() 341 | if metadata is not None: 342 | state_dict._metadata = metadata 343 | 344 | def load(module, prefix=''): 345 | local_metadata = {} if metadata is None else metadata.get( 346 | prefix[:-1], {}) 347 | module._load_from_state_dict( 348 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 349 | for name, child in module._modules.items(): 350 | if child is not None: 351 | load(child, prefix + name + '.') 352 | 353 | load(model, prefix=prefix) 354 | 355 | warn_missing_keys = [] 356 | ignore_missing_keys = [] 357 | for key in missing_keys: 358 | keep_flag = True 359 | for ignore_key in ignore_missing.split('|'): 360 | if ignore_key in key: 361 | keep_flag = False 362 | break 363 | if keep_flag: 364 | warn_missing_keys.append(key) 365 | else: 366 | ignore_missing_keys.append(key) 367 | 368 | missing_keys = warn_missing_keys 369 | 370 | if len(missing_keys) > 0: 371 | print("Weights of {} not initialized from pretrained model: {}".format( 372 | model.__class__.__name__, missing_keys)) 373 | if len(unexpected_keys) > 0: 374 | print("Weights from pretrained model not used in {}: {}".format( 
375 | model.__class__.__name__, unexpected_keys)) 376 | if len(ignore_missing_keys) > 0: 377 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 378 | model.__class__.__name__, ignore_missing_keys)) 379 | if len(error_msgs) > 0: 380 | print('\n'.join(error_msgs)) 381 | 382 | 383 | class NativeScalerWithGradNormCount: 384 | state_dict_key = "amp_scaler" 385 | 386 | def __init__(self): 387 | self._scaler = torch.cuda.amp.GradScaler() 388 | 389 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 390 | self._scaler.scale(loss).backward(create_graph=create_graph) 391 | if update_grad: 392 | if clip_grad is not None: 393 | assert parameters is not None 394 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 395 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 396 | else: 397 | self._scaler.unscale_(optimizer) 398 | norm = get_grad_norm_(parameters) 399 | self._scaler.step(optimizer) 400 | self._scaler.update() 401 | else: 402 | norm = None 403 | return norm 404 | 405 | def state_dict(self): 406 | return self._scaler.state_dict() 407 | 408 | def load_state_dict(self, state_dict): 409 | self._scaler.load_state_dict(state_dict) 410 | 411 | 412 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 413 | if isinstance(parameters, torch.Tensor): 414 | parameters = [parameters] 415 | parameters = [p for p in parameters if p.grad is not None] 416 | norm_type = float(norm_type) 417 | if len(parameters) == 0: 418 | return torch.tensor(0.) 419 | device = parameters[0].grad.device 420 | if norm_type == inf: 421 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 422 | else: 423 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 424 | return total_norm 425 | 426 | 427 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 428 | start_warmup_value=0, warmup_steps=-1): 429 | warmup_schedule = np.array([]) 430 | warmup_iters = warmup_epochs * niter_per_ep 431 | if warmup_steps > 0: 432 | warmup_iters = warmup_steps 433 | print("Set warmup steps = %d" % warmup_iters) 434 | if warmup_epochs > 0: 435 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 436 | 437 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 438 | schedule = np.array( 439 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 440 | 441 | schedule = np.concatenate((warmup_schedule, schedule)) 442 | 443 | assert len(schedule) == epochs * niter_per_ep 444 | return schedule 445 | 446 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 447 | output_dir = Path(args.output_dir) 448 | epoch_name = str(epoch) 449 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 450 | for checkpoint_path in checkpoint_paths: 451 | to_save = { 452 | 'model': model_without_ddp.state_dict(), 453 | 'optimizer': optimizer.state_dict(), 454 | 'epoch': epoch, 455 | 'scaler': loss_scaler.state_dict(), 456 | 'args': args, 457 | } 458 | 459 | if model_ema is not None: 460 | to_save['model_ema'] = get_state_dict(model_ema) 461 | 462 | save_on_master(to_save, checkpoint_path) 463 | 464 | if is_main_process() and isinstance(epoch, int): 465 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 466 | old_ckpt = output_dir / 
('checkpoint-%s.pth' % to_del) 467 | if os.path.exists(old_ckpt): 468 | os.remove(old_ckpt) 469 | 470 | 471 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 472 | output_dir = Path(args.output_dir) 473 | if args.auto_resume and len(args.resume) == 0: 474 | import glob 475 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 476 | latest_ckpt = -1 477 | for ckpt in all_checkpoints: 478 | t = ckpt.split('-')[-1].split('.')[0] 479 | if t.isdigit(): 480 | latest_ckpt = max(int(t), latest_ckpt) 481 | if latest_ckpt >= 0: 482 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 483 | print("Auto resume checkpoint: %s" % args.resume) 484 | 485 | if args.resume: 486 | if args.resume.startswith('https'): 487 | checkpoint = torch.hub.load_state_dict_from_url( 488 | args.resume, map_location='cpu', check_hash=True) 489 | else: 490 | checkpoint = torch.load(args.resume, map_location='cpu') 491 | model_without_ddp.load_state_dict(checkpoint['model']) 492 | print("Resume checkpoint %s" % args.resume) 493 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 494 | optimizer.load_state_dict(checkpoint['optimizer']) 495 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 496 | args.start_epoch = checkpoint['epoch'] + 1 497 | else: 498 | assert args.eval, 'Does not support resuming with checkpoint-best' 499 | if hasattr(args, 'model_ema') and args.model_ema: 500 | if 'model_ema' in checkpoint.keys(): 501 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 502 | else: 503 | model_ema.ema.load_state_dict(checkpoint['model']) 504 | if 'scaler' in checkpoint: 505 | loss_scaler.load_state_dict(checkpoint['scaler']) 506 | print("With optim & sched!") 507 | 508 | def reg_scheduler(base_value, final_value, epochs, niter_per_ep, early_epochs=0, early_value=None, 509 | mode='linear', early_mode='regular'): 510 | early_schedule = np.array([]) 511 | early_iters = early_epochs * niter_per_ep 512 | if early_value is None: 513 | early_value = final_value 514 | if early_epochs > 0: 515 | print(f"Set early value to {early_mode} {early_value}") 516 | if early_mode == 'regular': 517 | early_schedule = np.array([early_value] * early_iters) 518 | elif early_mode == 'linear': 519 | early_schedule = np.linspace(early_value, base_value, early_iters) 520 | elif early_mode == 'cosine': 521 | early_schedule = np.array( 522 | [base_value + 0.5 * (early_value - base_value) * (1 + math.cos(math.pi * i / early_iters)) for i in np.arange(early_iters)]) 523 | regular_epochs = epochs - early_epochs 524 | iters = np.arange(regular_epochs * niter_per_ep) 525 | schedule = np.linspace(base_value, final_value, len(iters)) 526 | schedule = np.concatenate((early_schedule, schedule)) 527 | 528 | assert len(schedule) == epochs * niter_per_ep 529 | return schedule 530 | 531 | def build_model(args): 532 | if args.model.startswith("convnext"): 533 | model = create_model( 534 | args.model, 535 | pretrained=False, 536 | num_classes=args.nb_classes, 537 | drop_path_rate=args.drop_path, 538 | layer_scale_init_value=args.layer_scale_init_value, 539 | head_init_scale=args.head_init_scale, 540 | drop_rate=args.dropout, 541 | ) 542 | else: 543 | model = create_model( 544 | args.model, 545 | pretrained=False, 546 | num_classes=args.nb_classes, 547 | drop_path_rate=args.drop_path, 548 | drop_rate =args.dropout 549 | ) 550 | return model --------------------------------------------------------------------------------
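
Usage note (illustrative): the models above expose `update_dropout` / `update_drop_path` hooks, and `utils.reg_scheduler` / `utils.cosine_scheduler` return one value per training iteration. Below is a minimal sketch, under assumed names, of how such a schedule could drive the dropout hooks; `args.epochs`, `args.early_epochs`, `data_loader`, and the loop are hypothetical placeholders, not the repository's actual training code.

    import utils

    # Hypothetical per-iteration dropout schedule: hold an early value for
    # args.early_epochs, then ramp linearly from base_value to final_value.
    drop_schedule = utils.reg_scheduler(
        base_value=0.0, final_value=0.1,
        epochs=args.epochs, niter_per_ep=len(data_loader),
        early_epochs=args.early_epochs, early_value=0.0)

    model = utils.build_model(args)
    for epoch in range(args.epochs):
        for it, (samples, targets) in enumerate(data_loader):
            step = epoch * len(data_loader) + it
            # Sets p on every nn.Dropout in the model (see update_dropout above).
            model.update_dropout(drop_schedule[step])
            # ... forward / backward / optimizer step ...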