├── .gitmodules ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── INSTALL.md ├── LICENSE ├── README.md ├── TRAINING.md ├── datasets.py ├── engine_finetune.py ├── engine_pretrain.py ├── figures ├── fcmae_convnextv2.png └── model_scaling.png ├── main_finetune.py ├── main_pretrain.py ├── models ├── convnextv2.py ├── convnextv2_sparse.py ├── fcmae.py └── utils.py ├── optim_factory.py ├── submitit_finetune.py ├── submitit_pretrain.py └── utils.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "MinkowskiEngine"] 2 | path = MinkowskiEngine 3 | url = git@github.com:shwoo93/MinkowskiEngine.git 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ConvNeXt 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to ConvNeXt, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | We provide installation instructions for ImageNet classification experiments here. 4 | 5 | ## Dependency Setup 6 | Create an new conda virtual environment 7 | ``` 8 | conda create -n convnextv2 python=3.8 -y 9 | conda activate convnextv2 10 | ``` 11 | 12 | Install [Pytorch](https://pytorch.org/)>=1.8.0, [torchvision](https://pytorch.org/vision/stable/index.html)>=0.9.0 following official instructions. 
For example: 13 | ``` 14 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 15 | ``` 16 | 17 | Clone this repo and install required packages: 18 | ``` 19 | git clone https://github.com/facebookresearch/ConvNeXt-V2.git 20 | pip install timm==0.3.2 tensorboardX six 21 | pip install submitit 22 | conda install openblas-devel -c anaconda -y 23 | ``` 24 | 25 | Install MinkowskiEngine: 26 | 27 | *(Note: we have implemented a customized CUDA kernel for depth-wise convolutions, which the original MinkowskiEngine does not support.)* 28 | ``` 29 | git submodule update --init --recursive 30 | git submodule update --recursive --remote 31 | cd MinkowskiEngine 32 | python setup.py install --blas_include_dirs=${CONDA_PREFIX}/include --blas=openblas 33 | ``` 34 | 35 | Install apex 36 | ``` 37 | git clone https://github.com/NVIDIA/apex 38 | cd apex 39 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 40 | cd .. 41 | ``` 42 | 43 | ## Dataset Preparation 44 | 45 | Download the [ImageNet-1K](http://image-net.org/) classification dataset and structure the data as follows: 46 | ``` 47 | /path/to/imagenet-1k/ 48 | train/ 49 | class1/ 50 | img1.jpeg 51 | class2/ 52 | img2.jpeg 53 | val/ 54 | class1/ 55 | img3.jpeg 56 | class2/ 57 | img4.jpeg 58 | ``` 59 | 60 | For pre-training on [ImageNet-22K](http://image-net.org/), download the dataset and structure the data as follows: 61 | ``` 62 | /path/to/imagenet-22k/ 63 | class1/ 64 | img1.jpeg 65 | class2/ 66 | img2.jpeg 67 | class3/ 68 | img3.jpeg 69 | class4/ 70 | img4.jpeg 71 | ``` 72 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | ======================================================================= 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | ======================================================================= 24 | 25 | Attribution-NonCommercial 4.0 International 26 | 27 | ======================================================================= 28 | 29 | Creative Commons Corporation ("Creative Commons") is not a law firm and 30 | does not provide legal services or legal advice. 
Distribution of 31 | Creative Commons public licenses does not create a lawyer-client or 32 | other relationship. Creative Commons makes its licenses and related 33 | information available on an "as-is" basis. Creative Commons gives no 34 | warranties regarding its licenses, any material licensed under their 35 | terms and conditions, or any related information. Creative Commons 36 | disclaims all liability for damages resulting from their use to the 37 | fullest extent possible. 38 | 39 | Using Creative Commons Public Licenses 40 | 41 | Creative Commons public licenses provide a standard set of terms and 42 | conditions that creators and other rights holders may use to share 43 | original works of authorship and other material subject to copyright 44 | and certain other rights specified in the public license below. The 45 | following considerations are for informational purposes only, are not 46 | exhaustive, and do not form part of our licenses. 47 | 48 | Considerations for licensors: Our public licenses are 49 | intended for use by those authorized to give the public 50 | permission to use material in ways otherwise restricted by 51 | copyright and certain other rights. Our licenses are 52 | irrevocable. Licensors should read and understand the terms 53 | and conditions of the license they choose before applying it. 54 | Licensors should also secure all rights necessary before 55 | applying our licenses so that the public can reuse the 56 | material as expected. Licensors should clearly mark any 57 | material not subject to the license. This includes other CC- 58 | licensed material, or material used under an exception or 59 | limitation to copyright. More considerations for licensors: 60 | wiki.creativecommons.org/Considerations_for_licensors 61 | 62 | Considerations for the public: By using one of our public 63 | licenses, a licensor grants the public permission to use the 64 | licensed material under specified terms and conditions. If 65 | the licensor's permission is not necessary for any reason--for 66 | example, because of any applicable exception or limitation to 67 | copyright--then that use is not regulated by the license. Our 68 | licenses grant only permissions under copyright and certain 69 | other rights that a licensor has authority to grant. Use of 70 | the licensed material may still be restricted for other 71 | reasons, including because others have copyright or other 72 | rights in the material. A licensor may make special requests, 73 | such as asking that all changes be marked or described. 74 | Although not required by our licenses, you are encouraged to 75 | respect those requests where reasonable. More_considerations 76 | for the public: 77 | wiki.creativecommons.org/Considerations_for_licensees 78 | 79 | ======================================================================= 80 | 81 | Creative Commons Attribution-NonCommercial 4.0 International Public 82 | License 83 | 84 | By exercising the Licensed Rights (defined below), You accept and agree 85 | to be bound by the terms and conditions of this Creative Commons 86 | Attribution-NonCommercial 4.0 International Public License ("Public 87 | License"). To the extent this Public License may be interpreted as a 88 | contract, You are granted the Licensed Rights in consideration of Your 89 | acceptance of these terms and conditions, and the Licensor grants You 90 | such rights in consideration of benefits the Licensor receives from 91 | making the Licensed Material available under these terms and 92 | conditions. 
93 | 94 | Section 1 -- Definitions. 95 | 96 | a. Adapted Material means material subject to Copyright and Similar 97 | Rights that is derived from or based upon the Licensed Material 98 | and in which the Licensed Material is translated, altered, 99 | arranged, transformed, or otherwise modified in a manner requiring 100 | permission under the Copyright and Similar Rights held by the 101 | Licensor. For purposes of this Public License, where the Licensed 102 | Material is a musical work, performance, or sound recording, 103 | Adapted Material is always produced where the Licensed Material is 104 | synched in timed relation with a moving image. 105 | 106 | b. Adapter's License means the license You apply to Your Copyright 107 | and Similar Rights in Your contributions to Adapted Material in 108 | accordance with the terms and conditions of this Public License. 109 | 110 | c. Copyright and Similar Rights means copyright and/or similar rights 111 | closely related to copyright including, without limitation, 112 | performance, broadcast, sound recording, and Sui Generis Database 113 | Rights, without regard to how the rights are labeled or 114 | categorized. For purposes of this Public License, the rights 115 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 116 | Rights. 117 | d. Effective Technological Measures means those measures that, in the 118 | absence of proper authority, may not be circumvented under laws 119 | fulfilling obligations under Article 11 of the WIPO Copyright 120 | Treaty adopted on December 20, 1996, and/or similar international 121 | agreements. 122 | 123 | e. Exceptions and Limitations means fair use, fair dealing, and/or 124 | any other exception or limitation to Copyright and Similar Rights 125 | that applies to Your use of the Licensed Material. 126 | 127 | f. Licensed Material means the artistic or literary work, database, 128 | or other material to which the Licensor applied this Public 129 | License. 130 | 131 | g. Licensed Rights means the rights granted to You subject to the 132 | terms and conditions of this Public License, which are limited to 133 | all Copyright and Similar Rights that apply to Your use of the 134 | Licensed Material and that the Licensor has authority to license. 135 | 136 | h. Licensor means the individual(s) or entity(ies) granting rights 137 | under this Public License. 138 | 139 | i. NonCommercial means not primarily intended for or directed towards 140 | commercial advantage or monetary compensation. For purposes of 141 | this Public License, the exchange of the Licensed Material for 142 | other material subject to Copyright and Similar Rights by digital 143 | file-sharing or similar means is NonCommercial provided there is 144 | no payment of monetary compensation in connection with the 145 | exchange. 146 | 147 | j. Share means to provide material to the public by any means or 148 | process that requires permission under the Licensed Rights, such 149 | as reproduction, public display, public performance, distribution, 150 | dissemination, communication, or importation, and to make material 151 | available to the public including in ways that members of the 152 | public may access the material from a place and at a time 153 | individually chosen by them. 154 | 155 | k. 
Sui Generis Database Rights means rights other than copyright 156 | resulting from Directive 96/9/EC of the European Parliament and of 157 | the Council of 11 March 1996 on the legal protection of databases, 158 | as amended and/or succeeded, as well as other essentially 159 | equivalent rights anywhere in the world. 160 | 161 | l. You means the individual or entity exercising the Licensed Rights 162 | under this Public License. Your has a corresponding meaning. 163 | 164 | Section 2 -- Scope. 165 | 166 | a. License grant. 167 | 168 | 1. Subject to the terms and conditions of this Public License, 169 | the Licensor hereby grants You a worldwide, royalty-free, 170 | non-sublicensable, non-exclusive, irrevocable license to 171 | exercise the Licensed Rights in the Licensed Material to: 172 | 173 | a. reproduce and Share the Licensed Material, in whole or 174 | in part, for NonCommercial purposes only; and 175 | 176 | b. produce, reproduce, and Share Adapted Material for 177 | NonCommercial purposes only. 178 | 179 | 2. Exceptions and Limitations. For the avoidance of doubt, where 180 | Exceptions and Limitations apply to Your use, this Public 181 | License does not apply, and You do not need to comply with 182 | its terms and conditions. 183 | 184 | 3. Term. The term of this Public License is specified in Section 185 | 6(a). 186 | 187 | 4. Media and formats; technical modifications allowed. The 188 | Licensor authorizes You to exercise the Licensed Rights in 189 | all media and formats whether now known or hereafter created, 190 | and to make technical modifications necessary to do so. The 191 | Licensor waives and/or agrees not to assert any right or 192 | authority to forbid You from making technical modifications 193 | necessary to exercise the Licensed Rights, including 194 | technical modifications necessary to circumvent Effective 195 | Technological Measures. For purposes of this Public License, 196 | simply making modifications authorized by this Section 2(a) 197 | (4) never produces Adapted Material. 198 | 199 | 5. Downstream recipients. 200 | 201 | a. Offer from the Licensor -- Licensed Material. Every 202 | recipient of the Licensed Material automatically 203 | receives an offer from the Licensor to exercise the 204 | Licensed Rights under the terms and conditions of this 205 | Public License. 206 | 207 | b. No downstream restrictions. You may not offer or impose 208 | any additional or different terms or conditions on, or 209 | apply any Effective Technological Measures to, the 210 | Licensed Material if doing so restricts exercise of the 211 | Licensed Rights by any recipient of the Licensed 212 | Material. 213 | 214 | 6. No endorsement. Nothing in this Public License constitutes or 215 | may be construed as permission to assert or imply that You 216 | are, or that Your use of the Licensed Material is, connected 217 | with, or sponsored, endorsed, or granted official status by, 218 | the Licensor or others designated to receive attribution as 219 | provided in Section 3(a)(1)(A)(i). 220 | 221 | b. Other rights. 222 | 223 | 1. Moral rights, such as the right of integrity, are not 224 | licensed under this Public License, nor are publicity, 225 | privacy, and/or other similar personality rights; however, to 226 | the extent possible, the Licensor waives and/or agrees not to 227 | assert any such rights held by the Licensor to the limited 228 | extent necessary to allow You to exercise the Licensed 229 | Rights, but not otherwise. 230 | 231 | 2. 
Patent and trademark rights are not licensed under this 232 | Public License. 233 | 234 | 3. To the extent possible, the Licensor waives any right to 235 | collect royalties from You for the exercise of the Licensed 236 | Rights, whether directly or through a collecting society 237 | under any voluntary or waivable statutory or compulsory 238 | licensing scheme. In all other cases the Licensor expressly 239 | reserves any right to collect such royalties, including when 240 | the Licensed Material is used other than for NonCommercial 241 | purposes. 242 | 243 | Section 3 -- License Conditions. 244 | 245 | Your exercise of the Licensed Rights is expressly made subject to the 246 | following conditions. 247 | 248 | a. Attribution. 249 | 250 | 1. If You Share the Licensed Material (including in modified 251 | form), You must: 252 | 253 | a. retain the following if it is supplied by the Licensor 254 | with the Licensed Material: 255 | 256 | i. identification of the creator(s) of the Licensed 257 | Material and any others designated to receive 258 | attribution, in any reasonable manner requested by 259 | the Licensor (including by pseudonym if 260 | designated); 261 | 262 | ii. a copyright notice; 263 | 264 | iii. a notice that refers to this Public License; 265 | 266 | iv. a notice that refers to the disclaimer of 267 | warranties; 268 | 269 | v. a URI or hyperlink to the Licensed Material to the 270 | extent reasonably practicable; 271 | 272 | b. indicate if You modified the Licensed Material and 273 | retain an indication of any previous modifications; and 274 | 275 | c. indicate the Licensed Material is licensed under this 276 | Public License, and include the text of, or the URI or 277 | hyperlink to, this Public License. 278 | 279 | 2. You may satisfy the conditions in Section 3(a)(1) in any 280 | reasonable manner based on the medium, means, and context in 281 | which You Share the Licensed Material. For example, it may be 282 | reasonable to satisfy the conditions by providing a URI or 283 | hyperlink to a resource that includes the required 284 | information. 285 | 286 | 3. If requested by the Licensor, You must remove any of the 287 | information required by Section 3(a)(1)(A) to the extent 288 | reasonably practicable. 289 | 290 | 4. If You Share Adapted Material You produce, the Adapter's 291 | License You apply must not prevent recipients of the Adapted 292 | Material from complying with this Public License. 293 | 294 | Section 4 -- Sui Generis Database Rights. 295 | 296 | Where the Licensed Rights include Sui Generis Database Rights that 297 | apply to Your use of the Licensed Material: 298 | 299 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 300 | to extract, reuse, reproduce, and Share all or a substantial 301 | portion of the contents of the database for NonCommercial purposes 302 | only; 303 | 304 | b. if You include all or a substantial portion of the database 305 | contents in a database in which You have Sui Generis Database 306 | Rights, then the database in which You have Sui Generis Database 307 | Rights (but not its individual contents) is Adapted Material; and 308 | 309 | c. You must comply with the conditions in Section 3(a) if You Share 310 | all or a substantial portion of the contents of the database. 311 | 312 | For the avoidance of doubt, this Section 4 supplements and does not 313 | replace Your obligations under this Public License where the Licensed 314 | Rights include other Copyright and Similar Rights. 
315 | 316 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 317 | 318 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 319 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 320 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 321 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 322 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 323 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 324 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 325 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 326 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 327 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 328 | 329 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 330 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 331 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 332 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 333 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 334 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 335 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 336 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 337 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 338 | 339 | c. The disclaimer of warranties and limitation of liability provided 340 | above shall be interpreted in a manner that, to the extent 341 | possible, most closely approximates an absolute disclaimer and 342 | waiver of all liability. 343 | 344 | Section 6 -- Term and Termination. 345 | 346 | a. This Public License applies for the term of the Copyright and 347 | Similar Rights licensed here. However, if You fail to comply with 348 | this Public License, then Your rights under this Public License 349 | terminate automatically. 350 | 351 | b. Where Your right to use the Licensed Material has terminated under 352 | Section 6(a), it reinstates: 353 | 354 | 1. automatically as of the date the violation is cured, provided 355 | it is cured within 30 days of Your discovery of the 356 | violation; or 357 | 358 | 2. upon express reinstatement by the Licensor. 359 | 360 | For the avoidance of doubt, this Section 6(b) does not affect any 361 | right the Licensor may have to seek remedies for Your violations 362 | of this Public License. 363 | 364 | c. For the avoidance of doubt, the Licensor may also offer the 365 | Licensed Material under separate terms or conditions or stop 366 | distributing the Licensed Material at any time; however, doing so 367 | will not terminate this Public License. 368 | 369 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 370 | License. 371 | 372 | Section 7 -- Other Terms and Conditions. 373 | 374 | a. The Licensor shall not be bound by any additional or different 375 | terms or conditions communicated by You unless expressly agreed. 376 | 377 | b. Any arrangements, understandings, or agreements regarding the 378 | Licensed Material not stated herein are separate from and 379 | independent of the terms and conditions of this Public License. 380 | 381 | Section 8 -- Interpretation. 382 | 383 | a. For the avoidance of doubt, this Public License does not, and 384 | shall not be interpreted to, reduce, limit, restrict, or impose 385 | conditions on any use of the Licensed Material that could lawfully 386 | be made without permission under this Public License. 387 | 388 | b. 
To the extent possible, if any provision of this Public License is 389 | deemed unenforceable, it shall be automatically reformed to the 390 | minimum extent necessary to make it enforceable. If the provision 391 | cannot be reformed, it shall be severed from this Public License 392 | without affecting the enforceability of the remaining terms and 393 | conditions. 394 | 395 | c. No term or condition of this Public License will be waived and no 396 | failure to comply consented to unless expressly agreed to by the 397 | Licensor. 398 | 399 | d. Nothing in this Public License constitutes or may be interpreted 400 | as a limitation upon, or waiver of, any privileges and immunities 401 | that apply to the Licensor or You, including from the legal 402 | processes of any jurisdiction or authority. 403 | 404 | ======================================================================= 405 | 406 | Creative Commons is not a party to its public 407 | licenses. Notwithstanding, Creative Commons may elect to apply one of 408 | its public licenses to material it publishes and in those instances 409 | will be considered the “Licensor.” The text of the Creative Commons 410 | public licenses is dedicated to the public domain under the CC0 Public 411 | Domain Dedication. Except for the limited purpose of indicating that 412 | material is shared under a Creative Commons public license or as 413 | otherwise permitted by the Creative Commons policies published at 414 | creativecommons.org/policies, Creative Commons does not authorize the 415 | use of the trademark "Creative Commons" or any other trademark or logo 416 | of Creative Commons without its prior written consent including, 417 | without limitation, in connection with any unauthorized modifications 418 | to any of its public licenses or any other arrangements, 419 | understandings, or agreements concerning use of licensed material. For 420 | the avoidance of doubt, this paragraph does not form part of the 421 | public licenses. 422 | 423 | Creative Commons may be contacted at creativecommons.org. 424 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ConvNeXt V2
Official PyTorch Implementation 2 | 3 | This repo contains the PyTorch version of *8* model definitions (*Atto, Femto, Pico, Nano, Tiny, Base, Large, Huge*), pre-training/fine-tuning code and pre-trained weights (converted from JAX weights trained on TPU) for our ConvNeXt V2 paper. 4 | 5 | > [**ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders**](http://arxiv.org/abs/2301.00808)
6 | > [Sanghyun Woo](https://sites.google.com/view/sanghyunwoo/), [Shoubhik Debnath](https://www.linkedin.com/in/shoubhik-debnath-41268570/), [Ronghang Hu](https://ronghanghu.com/), [Xinlei Chen](https://xinleic.xyz/), [Zhuang Liu](https://liuzhuang13.github.io/), [In So Kweon](https://scholar.google.com/citations?user=XA8EOlEAAAAJ&hl=en) and [Saining Xie](https://sainingxie.com)\ 7 | >
KAIST, Meta AI and New York University
8 | 9 | We propose a fully convolutional masked autoencoder framework (FCMAE) and a new Global Response Normalization (GRN) layer that can be added to the ConvNeXt architecture to enhance inter-channel feature competition. This co-design of self-supervised learning techniques and architectural improvement results in a new model family called ConvNeXt V2, which significantly improves the performance of pure ConvNets on various recognition benchmarks. We also provide pre-trained ConvNeXt V2 models of various sizes. 10 | 11 |
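For reference, GRN first aggregates a global response per channel (the L2 norm over the spatial dimensions), normalizes it by its mean across channels, and uses the result to recalibrate the input with learnable affine parameters and a residual connection. Below is a minimal PyTorch sketch for channels-last (N, H, W, C) feature maps, mirroring the JAX definition in [TRAINING.md](TRAINING.md); the layers shipped under `models/` are the reference implementation.

```python
import torch
import torch.nn as nn

class GRN(nn.Module):
    """Global Response Normalization (sketch) for channels-last inputs (N, H, W, C)."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.eps = eps

    def forward(self, x):
        gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)     # global aggregation: L2 norm per channel
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)  # divisive normalization across channels
        return self.gamma * (x * nx) + self.beta + x          # feature calibration + residual
```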

12 | <img src="figures/fcmae_convnextv2.png" alt="FCMAE pre-training framework with ConvNeXt V2"> 13 | 14 | 15 | <img src="figures/model_scaling.png" alt="ConvNeXt V2 model scaling results"> 16 | 17 |

18 | 19 | ## Results and Pre-trained Models 20 | ### ImageNet-1K FCMAE pre-trained weights (*self-supervised*) 21 | | name | resolution | #params | model | 22 | |:---:|:---:|:---:|:---:| 23 | | ConvNeXt V2-A | 224x224 | 3.7M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt) | 24 | | ConvNeXt V2-F | 224x224 | 5.2M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt) | 25 | | ConvNeXt V2-P | 224x224 | 9.1M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt) | 26 | | ConvNeXt V2-N | 224x224 | 15.6M| [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt) | 27 | | ConvNeXt V2-T | 224x224 | 28.6M| [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt) | 28 | | ConvNeXt V2-B | 224x224 | 89M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt) | 29 | | ConvNeXt V2-L | 224x224 | 198M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt) | 30 | | ConvNeXt V2-H | 224x224 | 660M | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt) | 31 | 32 | ### ImageNet-1K fine-tuned models 33 | 34 | | name | resolution |acc@1 | #params | FLOPs | model | 35 | |:---:|:---:|:---:|:---:| :---:|:---:| 36 | | ConvNeXt V2-A | 224x224 | 76.7 | 3.7M | 0.55G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt) | 37 | | ConvNeXt V2-F | 224x224 | 78.5 | 5.2M | 0.78G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt) | 38 | | ConvNeXt V2-P | 224x224 | 80.3 | 9.1M | 1.37G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt) | 39 | | ConvNeXt V2-N | 224x224 | 81.9 | 15.6M | 2.45G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt) | 40 | | ConvNeXt V2-T | 224x224 | 83.0 | 28.6M | 4.47G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt) | 41 | | ConvNeXt V2-B | 224x224 | 84.9 | 89M | 15.4G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt) | 42 | | ConvNeXt V2-L | 224x224 | 85.8 | 198M | 34.4G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt) | 43 | | ConvNeXt V2-H | 224x224 | 86.3 | 660M | 115G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt) | 44 | 45 | ### ImageNet-22K fine-tuned models 46 | 47 | | name | resolution |acc@1 | #params | FLOPs | model | 48 | |:---:|:---:|:---:|:---:| :---:| :---:| 49 | | ConvNeXt V2-N | 224x224 | 82.1 | 15.6M | 2.45G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt)| 50 | | ConvNeXt V2-N | 384x384 | 83.4 | 15.6M | 7.21G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt)| 51 | | ConvNeXt V2-T | 224x224 | 83.9 | 28.6M | 4.47G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt)| 52 | | ConvNeXt V2-T | 384x384 | 85.1 | 28.6M | 13.1G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt)| 53 | | ConvNeXt V2-B | 224x224 | 86.8 | 89M | 15.4G | 
[model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt)| 54 | | ConvNeXt V2-B | 384x384 | 87.7 | 89M | 45.2G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt)| 55 | | ConvNeXt V2-L | 224x224 | 87.3 | 198M | 34.4G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt)| 56 | | ConvNeXt V2-L | 384x384 | 88.2 | 198M | 101.1G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt)| 57 | | ConvNeXt V2-H | 384x384 | 88.7 | 660M | 337.9G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt)| 58 | | ConvNeXt V2-H | 512x512 | 88.9 | 660M | 600.8G | [model](https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt)| 59 | 60 | ## Installation 61 | Please check [INSTALL.md](INSTALL.md) for installation instructions. 62 | 63 | ## Evaluation 64 | We provide example evaluation commands for ConvNeXt V2-Base: 65 | 66 | Single-GPU 67 | ``` 68 | python main_finetune.py \ 69 | --model convnextv2_base \ 70 | --eval true \ 71 | --resume /path/to/checkpoint \ 72 | --input_size 224 \ 73 | --data_path /path/to/imagenet-1k \ 74 | ``` 75 | Multi-GPU 76 | ``` 77 | python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 78 | --model convnextv2_base \ 79 | --eval true \ 80 | --resume /path/to/checkpoint \ 81 | --input_size 224 \ 82 | --data_path /path/to/imagenet-1k \ 83 | ``` 84 | 85 | - For evaluating other model variants, change `--model`, `--resume`, `--input_size` accordingly. URLs for the pre-trained models can be found from the result tables. 86 | - Setting model-specific `--drop_path` is not strictly required in evaluation, as the `DropPath` module in timm behaves the same during evaluation; but it is required in training. See [TRAINING.md](TRAINING.md) or our paper (appendix) for the values used for different models. 87 | 88 | ## Training 89 | See [TRAINING.md](TRAINING.md) for pre-training and fine-tuning instructions. 90 | 91 | ## Acknowledgement 92 | This repository borrows from [timm](https://github.com/rwightman/pytorch-image-models), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) and [MAE](https://github.com/facebookresearch/mae). 93 | 94 | We thank Ross Wightman for the initial design of the small-compute ConvNeXt model variants and the associated training recipe. We also appreciate the helpful discussions and feedback provided by Kaiming He. 95 | 96 | ## License 97 | This project is released under the MIT license except ImageNet pre-trained and fine-tuned models which are licensed under a CC-BY-NC. Please see the [LICENSE](LICENSE) file for more information. 98 | 99 | ## Citation 100 | If you find this repository helpful, please consider citing: 101 | ```bibtex 102 | @article{Woo2023ConvNeXtV2, 103 | title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders}, 104 | author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie}, 105 | year={2023}, 106 | journal={arXiv preprint arXiv:2301.00808}, 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /TRAINING.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | We provide FCMAE ImageNet-1K pre-training and fine-tuning scripts here. 4 | Please check [INSTALL.md](INSTALL.md) for installation instructions first. 
5 | 6 | ## Multi-node Training 7 | We use multi-node training on a SLURM cluster with [submitit](https://github.com/facebookincubator/submitit) for reproducing the results in the paper. Please install: 8 | ``` 9 | pip install submitit 10 | ``` 11 | We provide example commands for both multi-node and single-machine training below. 12 | 13 | 14 | ## ImageNet-1K FCMAE Pre-Training 15 | ConvNeXt V2-Base pre-training on ImageNet-1K with 8 8-GPU nodes: 16 | ``` 17 | python submitit_pretrain.py --nodes 8 --ngpus 8 \ 18 | --model convnextv2_base \ 19 | --batch_size 64 \ 20 | --blr 1.5e-4 \ 21 | --epochs 1600 \ 22 | --warmup_epochs 40 \ 23 | --data_path /path/to/imagenet-1k \ 24 | --job_dir /path/to/save_results 25 | ``` 26 | 27 | The following commands run the pre-training on a single machine: 28 | 29 | ``` 30 | python -m torch.distributed.launch --nproc_per_node=8 main_pretrain.py \ 31 | --model convnextv2_base \ 32 | --batch_size 64 --update_freq 8 \ 33 | --blr 1.5e-4 \ 34 | --epochs 1600 \ 35 | --warmup_epochs 40 \ 36 | --data_path /path/to/imagenet-1k \ 37 | --output_dir /path/to/save_results 38 | ``` 39 | 40 | 41 | ## ImageNet-1K Fine-Tuning 42 | 43 | ConvNeXt V2-Base fine-tuning on ImageNet-1K with 4 8-GPU nodes: 44 | ``` 45 | python submitit_finetune.py --nodes 4 --ngpus 8 \ 46 | --model convnextv2_base \ 47 | --batch_size 32 \ 48 | --blr 6.25e-4 \ 49 | --epochs 100 \ 50 | --warmup_epochs 20 \ 51 | --layer_decay_type 'group' \ 52 | --layer_decay 0.6 \ 53 | --weight_decay 0.05 \ 54 | --drop_path 0.1 \ 55 | --reprob 0.25 \ 56 | --mixup 0.8 \ 57 | --cutmix 1.0 \ 58 | --smoothing 0.1 \ 59 | --model_ema True --model_ema_eval True \ 60 | --use_amp True \ 61 | --finetune /path/to/checkpoint \ 62 | --data_path /path/to/imagenet-1k \ 63 | --job_dir /path/to/save_results 64 | ``` 65 | 66 | The following commands run the fine-tuning on a single machine: 67 | 68 | ``` 69 | python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 70 | --model convnextv2_base \ 71 | --batch_size 32 --update_freq 4 \ 72 | --blr 6.25e-4 \ 73 | --epochs 100 \ 74 | --warmup_epochs 20 \ 75 | --layer_decay_type 'group' \ 76 | --layer_decay 0.6 \ 77 | --weight_decay 0.05 \ 78 | --drop_path 0.1 \ 79 | --reprob 0.25 \ 80 | --mixup 0.8 \ 81 | --cutmix 1.0 \ 82 | --smoothing 0.1 \ 83 | --model_ema True --model_ema_eval True \ 84 | --use_amp True \ 85 | --finetune /path/to/checkpoint \ 86 | --data_path /path/to/imagenet-1k \ 87 | --output_dir /path/to/save_results 88 | ``` 89 | 90 |
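In both the multi-node and single-machine variants above, `--update_freq` (gradient accumulation) keeps the effective batch size the same, and the absolute learning rate follows the linear scaling rule described in the `--blr` help text of `main_finetune.py` (`absolute_lr = blr * effective_batch_size / 256`). A back-of-the-envelope check with the numbers from the example commands (illustrative only, not part of the training scripts):

```python
def effective_batch_size(batch_size, ngpus, nodes=1, update_freq=1):
    # per-GPU batch size x GPUs per node x nodes x gradient accumulation steps
    return batch_size * ngpus * nodes * update_freq

def absolute_lr(blr, eff_batch):
    # linear scaling rule: absolute_lr = base_lr * total_batch_size / 256
    return blr * eff_batch / 256

# FCMAE pre-training: 8 nodes x 8 GPUs vs. 1 node x 8 GPUs with --update_freq 8
assert effective_batch_size(64, 8, nodes=8) == effective_batch_size(64, 8, update_freq=8) == 4096
# ConvNeXt V2-Base fine-tuning: 4 nodes x 8 GPUs vs. 1 node x 8 GPUs with --update_freq 4
assert effective_batch_size(32, 8, nodes=4) == effective_batch_size(32, 8, update_freq=4) == 1024
print(absolute_lr(6.25e-4, 1024))  # 0.0025 absolute lr for the fine-tuning example above
```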
91 | <details> 92 | <summary>ConvNeXt-A</summary> 93 | 94 | 95 | ConvNeXt V2-Atto fine-tuning on ImageNet-1K with 4 8-GPU nodes: 96 | ``` 97 | python submitit_finetune.py --nodes 4 --ngpus 8 \ 98 | --model convnextv2_atto \ 99 | --batch_size 32 \ 100 | --blr 2e-4 \ 101 | --epochs 600 \ 102 | --warmup_epochs 0 \ 103 | --layer_decay_type 'single' \ 104 | --layer_decay 0.9 \ 105 | --weight_decay 0.3 \ 106 | --drop_path 0.1 \ 107 | --reprob 0.25 \ 108 | --mixup 0. \ 109 | --cutmix 0. \ 110 | --smoothing 0.2 \ 111 | --model_ema True --model_ema_eval True \ 112 | --use_amp True \ 113 | --finetune /path/to/checkpoint \ 114 | --data_path /path/to/imagenet-1k \ 115 | --job_dir /path/to/save_results 116 | ``` 117 | 118 | The following commands run the fine-tuning on a single machine: 119 | ``` 120 | python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 121 | --model convnextv2_atto \ 122 | --batch_size 32 --update_freq 4 \ 123 | --blr 2e-4 \ 124 | --epochs 600 \ 125 | --warmup_epochs 0 \ 126 | --layer_decay_type 'single' \ 127 | --layer_decay 0.9 \ 128 | --weight_decay 0.3 \ 129 | --drop_path 0.1 \ 130 | --reprob 0.25 \ 131 | --mixup 0. \ 132 | --cutmix 0. \ 133 | --smoothing 0.2 \ 134 | --model_ema True --model_ema_eval True \ 135 | --use_amp True \ 136 | --finetune /path/to/checkpoint \ 137 | --data_path /path/to/imagenet-1k \ 138 | --output_dir /path/to/save_results 139 | ``` 140 | </details> 141 | 142 |
143 | <details> 144 | <summary>ConvNeXt-T</summary> 145 | 146 | 147 | ConvNeXt V2-Tiny fine-tuning on ImageNet-1K with 4 8-GPU nodes: 148 | ``` 149 | python submitit_finetune.py --nodes 4 --ngpus 8 \ 150 | --model convnextv2_tiny \ 151 | --batch_size 32 \ 152 | --blr 8e-4 \ 153 | --epochs 300 \ 154 | --warmup_epochs 40 \ 155 | --layer_decay_type 'single' \ 156 | --layer_decay 0.9 \ 157 | --weight_decay 0.05 \ 158 | --drop_path 0.2 \ 159 | --reprob 0.25 \ 160 | --mixup 0.8 \ 161 | --cutmix 1.0 \ 162 | --smoothing 0.1 \ 163 | --model_ema True --model_ema_eval True \ 164 | --use_amp True \ 165 | --finetune /path/to/checkpoint \ 166 | --data_path /path/to/imagenet-1k \ 167 | --job_dir /path/to/save_results 168 | ``` 169 | 170 | The following commands run the fine-tuning on a single machine: 171 | ``` 172 | python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 173 | --model convnextv2_tiny \ 174 | --batch_size 32 --update_freq 4 \ 175 | --blr 8e-4 \ 176 | --epochs 300 \ 177 | --warmup_epochs 40 \ 178 | --layer_decay_type 'single' \ 179 | --layer_decay 0.9 \ 180 | --weight_decay 0.05 \ 181 | --drop_path 0.2 \ 182 | --reprob 0.25 \ 183 | --mixup 0.8 \ 184 | --cutmix 1.0 \ 185 | --smoothing 0.1 \ 186 | --model_ema True --model_ema_eval True \ 187 | --use_amp True \ 188 | --finetune /path/to/checkpoint \ 189 | --data_path /path/to/imagenet-1k \ 190 | --output_dir /path/to/save_results 191 | ``` 192 | </details>
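The fine-tuning recipes above rely on layer-wise learning-rate decay (`--layer_decay`, `--layer_decay_type`): parameter groups closer to the input receive smaller learning rates, and with the `'group'` strategy three consecutive layers share one decay value (the actual assignment is handled by `LayerDecayValueAssigner` in `optim_factory.py`). A rough, hypothetical illustration of the `'single'` strategy:

```python
def layer_lr_scales(num_layers, layer_decay):
    # scale for layer i is layer_decay ** (num_layers - i); the head (i = num_layers) keeps 1.0
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

scales = layer_lr_scales(num_layers=12, layer_decay=0.9)
print(scales[0], scales[-1])  # ~0.28 for the earliest layers, 1.0 for the classifier head
```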
193 | 194 | ## Implementing FCMAE with Masked Convolution in JAX 195 | 196 | In our paper, we trained our main results using the JAX framework on TPU VM Pods. However, we do not have an efficient sparse convolution kernel implementation in this environment. Therefore, we have included our JAX model definition that uses a masked (dense) convolution for FCMAE pre-training. 197 | 198 | ```python 199 | 200 | from flax import linen as nn 201 | import jax.numpy as jnp 202 | 203 | class GRN(nn.Module): 204 | dim: int 205 | eps: float = 1e-6 206 | 207 | def init_fn(self, key, shape, fill_value): 208 | return jnp.full(shape, fill_value) 209 | 210 | @nn.compact 211 | def __call__(self, inputs, mask=None): 212 | gamma = self.param("gamma", self.init_fn, (self.dim,), 0.) 213 | beta = self.param("beta", self.init_fn, (self.dim,), 0.) 214 | 215 | x = inputs 216 | if mask is not None: 217 | x = x * (1. - mask) 218 | Gx = jnp.power((jnp.sum(jnp.power(x, 2), axis=(1,2), keepdims=True) + self.eps), 0.5) 219 | Nx = Gx / (jnp.mean(Gx, axis=-1, keepdims=True) + self.eps) 220 | return gamma * (Nx * inputs) + beta + inputs 221 | 222 | class Block(nn.Module): 223 | dim: int 224 | drop_path: float 225 | training: bool = True # toggles the DropPath dropout below 226 | @nn.compact 227 | def __call__(self, inputs, mask=None): 228 | x = inputs 229 | if mask is not None: x = x * (1. - mask) 230 | x = DepthwiseConv2D((7, 7), name='dwconv')(x) # 7x7 depth-wise conv; DepthwiseConv2D is assumed to be defined elsewhere (e.g. nn.Conv with feature_group_count=dim) 231 | if mask is not None: # The binary masking is numerically identical to sparse conv. 232 | x = x * (1. - mask) 233 | x = nn.LayerNorm(name='norm')(x) 234 | x = nn.Dense(4 * self.dim, name='pwconv1')(x) 235 | x = nn.gelu(x) 236 | x = GRN(4 * self.dim, name='grn')(x, mask) 237 | x = nn.Dense(self.dim, name='pwconv2')(x) 238 | x = nn.Dropout(rate=self.drop_path, broadcast_dims=(1,2,3), name='droppath')(x, deterministic=not self.training) 239 | return x + inputs 240 | ``` 241 | 242 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree.
7 | 8 | 9 | import os 10 | from torchvision import datasets, transforms 11 | 12 | from timm.data.constants import \ 13 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 14 | from timm.data import create_transform 15 | 16 | def build_dataset(is_train, args): 17 | transform = build_transform(is_train, args) 18 | 19 | print("Transform = ") 20 | if isinstance(transform, tuple): 21 | for trans in transform: 22 | print(" - - - - - - - - - - ") 23 | for t in trans.transforms: 24 | print(t) 25 | else: 26 | for t in transform.transforms: 27 | print(t) 28 | print("---------------------------") 29 | 30 | if args.data_set == 'CIFAR': 31 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 32 | nb_classes = 100 33 | elif args.data_set == 'IMNET': 34 | print("reading from datapath", args.data_path) 35 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 36 | dataset = datasets.ImageFolder(root, transform=transform) 37 | nb_classes = 1000 38 | elif args.data_set == "image_folder": 39 | root = args.data_path if is_train else args.eval_data_path 40 | dataset = datasets.ImageFolder(root, transform=transform) 41 | nb_classes = args.nb_classes 42 | assert len(dataset.class_to_idx) == nb_classes 43 | else: 44 | raise NotImplementedError() 45 | print("Number of the class = %d" % nb_classes) 46 | 47 | return dataset, nb_classes 48 | 49 | 50 | def build_transform(is_train, args): 51 | resize_im = args.input_size > 32 52 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 53 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 54 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 55 | 56 | if is_train: 57 | # this should always dispatch to transforms_imagenet_train 58 | transform = create_transform( 59 | input_size=args.input_size, 60 | is_training=True, 61 | color_jitter=args.color_jitter, 62 | auto_augment=args.aa, 63 | interpolation=args.train_interpolation, 64 | re_prob=args.reprob, 65 | re_mode=args.remode, 66 | re_count=args.recount, 67 | mean=mean, 68 | std=std, 69 | ) 70 | if not resize_im: 71 | transform.transforms[0] = transforms.RandomCrop( 72 | args.input_size, padding=4) 73 | return transform 74 | 75 | t = [] 76 | if resize_im: 77 | # warping (no cropping) when evaluated at 384 or larger 78 | if args.input_size >= 384: 79 | t.append( 80 | transforms.Resize((args.input_size, args.input_size), 81 | interpolation=transforms.InterpolationMode.BICUBIC), 82 | ) 83 | print(f"Warping {args.input_size} size input images...") 84 | else: 85 | if args.crop_pct is None: 86 | args.crop_pct = 224 / 256 87 | size = int(args.input_size / args.crop_pct) 88 | t.append( 89 | # to maintain same ratio w.r.t. 224 images 90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 91 | ) 92 | t.append(transforms.CenterCrop(args.input_size)) 93 | 94 | t.append(transforms.ToTensor()) 95 | t.append(transforms.Normalize(mean, std)) 96 | return transforms.Compose(t) 97 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
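# NOTE: train_one_epoch below accumulates gradients over args.update_freq iterations --
# the loss is divided by update_freq and the optimizer (and EMA update) only steps every
# update_freq iterations, so the effective batch size is per-GPU batch_size * world size *
# update_freq. The learning rate is adjusted per iteration (at the start of each
# accumulation cycle) rather than per epoch.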
7 | 8 | 9 | import math 10 | from typing import Iterable, Optional 11 | 12 | import torch 13 | 14 | from timm.data import Mixup 15 | from timm.utils import accuracy, ModelEma 16 | 17 | import utils 18 | from utils import adjust_learning_rate 19 | 20 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 21 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 22 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 23 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 24 | log_writer=None, args=None): 25 | model.train(True) 26 | metric_logger = utils.MetricLogger(delimiter=" ") 27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | header = 'Epoch: [{}]'.format(epoch) 29 | print_freq = 20 30 | 31 | update_freq = args.update_freq 32 | use_amp = args.use_amp 33 | optimizer.zero_grad() 34 | 35 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 36 | # we use a per iteration (instead of per epoch) lr scheduler 37 | if data_iter_step % update_freq == 0: 38 | adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 39 | 40 | samples = samples.to(device, non_blocking=True) 41 | targets = targets.to(device, non_blocking=True) 42 | 43 | if mixup_fn is not None: 44 | samples, targets = mixup_fn(samples, targets) 45 | 46 | if use_amp: 47 | with torch.cuda.amp.autocast(): 48 | output = model(samples) 49 | loss = criterion(output, targets) 50 | else: # full precision 51 | output = model(samples) 52 | loss = criterion(output, targets) 53 | 54 | loss_value = loss.item() 55 | 56 | if not math.isfinite(loss_value): 57 | print("Loss is {}, stopping training".format(loss_value)) 58 | assert math.isfinite(loss_value) 59 | 60 | if use_amp: 61 | # this attribute is added by timm on one optimizer (adahessian) 62 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 63 | loss /= update_freq 64 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 65 | parameters=model.parameters(), create_graph=is_second_order, 66 | update_grad=(data_iter_step + 1) % update_freq == 0) 67 | if (data_iter_step + 1) % update_freq == 0: 68 | optimizer.zero_grad() 69 | if model_ema is not None: 70 | model_ema.update(model) 71 | else: # full precision 72 | loss /= update_freq 73 | loss.backward() 74 | if (data_iter_step + 1) % update_freq == 0: 75 | optimizer.step() 76 | optimizer.zero_grad() 77 | if model_ema is not None: 78 | model_ema.update(model) 79 | 80 | torch.cuda.synchronize() 81 | 82 | if mixup_fn is None: 83 | class_acc = (output.max(-1)[-1] == targets).float().mean() 84 | else: 85 | class_acc = None 86 | 87 | metric_logger.update(loss=loss_value) 88 | metric_logger.update(class_acc=class_acc) 89 | min_lr = 10. 90 | max_lr = 0. 
91 | for group in optimizer.param_groups: 92 | min_lr = min(min_lr, group["lr"]) 93 | max_lr = max(max_lr, group["lr"]) 94 | 95 | metric_logger.update(lr=max_lr) 96 | metric_logger.update(min_lr=min_lr) 97 | weight_decay_value = None 98 | for group in optimizer.param_groups: 99 | if group["weight_decay"] > 0: 100 | weight_decay_value = group["weight_decay"] 101 | metric_logger.update(weight_decay=weight_decay_value) 102 | if use_amp: 103 | metric_logger.update(grad_norm=grad_norm) 104 | if log_writer is not None: 105 | log_writer.update(loss=loss_value, head="loss") 106 | log_writer.update(class_acc=class_acc, head="loss") 107 | log_writer.update(lr=max_lr, head="opt") 108 | log_writer.update(min_lr=min_lr, head="opt") 109 | log_writer.update(weight_decay=weight_decay_value, head="opt") 110 | if use_amp: 111 | log_writer.update(grad_norm=grad_norm, head="opt") 112 | log_writer.set_step() 113 | 114 | # gather the stats from all processes 115 | metric_logger.synchronize_between_processes() 116 | print("Averaged stats:", metric_logger) 117 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 118 | 119 | @torch.no_grad() 120 | def evaluate(data_loader, model, device, use_amp=False): 121 | criterion = torch.nn.CrossEntropyLoss() 122 | 123 | metric_logger = utils.MetricLogger(delimiter=" ") 124 | header = 'Test:' 125 | 126 | # switch to evaluation mode 127 | model.eval() 128 | 129 | for batch in metric_logger.log_every(data_loader, 10, header): 130 | images = batch[0] 131 | target = batch[-1] 132 | 133 | images = images.to(device, non_blocking=True) 134 | target = target.to(device, non_blocking=True) 135 | 136 | # compute output 137 | if use_amp: 138 | with torch.cuda.amp.autocast(): 139 | output = model(images) 140 | if isinstance(output, dict): 141 | output = output['logits'] 142 | loss = criterion(output, target) 143 | else: 144 | output = model(images) 145 | if isinstance(output, dict): 146 | output = output['logits'] 147 | loss = criterion(output, target) 148 | 149 | torch.cuda.synchronize() 150 | 151 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 152 | 153 | batch_size = images.shape[0] 154 | metric_logger.update(loss=loss.item()) 155 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 156 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 157 | # gather the stats from all processes 158 | metric_logger.synchronize_between_processes() 159 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 160 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 161 | 162 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
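# NOTE: in the FCMAE pre-training loop below, the model call returns a 3-tuple whose first
# element is the loss (the remaining outputs are ignored here). Gradients are likewise
# accumulated over args.update_freq iterations, and torch.cuda.empty_cache() is called after
# each optimizer step to keep MinkowskiEngine (sparse conv) memory usage in check.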
7 | 8 | 9 | import math 10 | import sys 11 | from typing import Iterable 12 | 13 | import torch 14 | import utils 15 | 16 | def train_one_epoch(model: torch.nn.Module, 17 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 18 | device: torch.device, epoch: int, loss_scaler, 19 | log_writer=None, 20 | args=None): 21 | model.train(True) 22 | metric_logger = utils.MetricLogger(delimiter=" ") 23 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 24 | header = 'Epoch: [{}]'.format(epoch) 25 | print_freq = 20 26 | 27 | update_freq = args.update_freq 28 | 29 | optimizer.zero_grad() 30 | for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 31 | # we use a per iteration (instead of per epoch) lr scheduler 32 | if data_iter_step % update_freq == 0: 33 | utils.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 34 | 35 | if not isinstance(samples, list): 36 | samples = samples.to(device, non_blocking=True) 37 | labels = labels.to(device, non_blocking=True) 38 | 39 | loss, _, _ = model(samples, labels, mask_ratio=args.mask_ratio) 40 | 41 | loss_value = loss.item() 42 | 43 | if not math.isfinite(loss_value): 44 | print("Loss is {}, stopping training".format(loss_value)) 45 | sys.exit(1) 46 | 47 | loss /= update_freq 48 | loss_scaler(loss, optimizer, parameters=model.parameters(), 49 | update_grad=(data_iter_step + 1) % update_freq == 0) 50 | if (data_iter_step + 1) % update_freq == 0: 51 | optimizer.zero_grad() 52 | torch.cuda.empty_cache() # clear the GPU cache at a regular interval for training ME network 53 | 54 | metric_logger.update(loss=loss_value) 55 | 56 | lr = optimizer.param_groups[0]["lr"] 57 | metric_logger.update(lr=lr) 58 | 59 | loss_value_reduce = utils.all_reduce_mean(loss_value) 60 | if log_writer is not None and (data_iter_step + 1) % update_freq == 0: 61 | """ We use epoch_1000x as the x-axis in tensorboard. 62 | This calibrates different curves when batch size changes. 63 | """ 64 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 65 | log_writer.update(train_loss=loss_value_reduce, head="loss", step=epoch_1000x) 66 | log_writer.update(lr=lr, head="opt", step=epoch_1000x) 67 | 68 | metric_logger.synchronize_between_processes() 69 | print("Averaged stats:", metric_logger) 70 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /figures/fcmae_convnextv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ConvNeXt-V2/2553895753323c6fe0b2bf390683f5ea358a42b9/figures/fcmae_convnextv2.png -------------------------------------------------------------------------------- /figures/model_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ConvNeXt-V2/2553895753323c6fe0b2bf390683f5ea358a42b9/figures/model_scaling.png -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | import argparse 10 | import datetime 11 | import numpy as np 12 | import time 13 | import json 14 | import os 15 | from pathlib import Path 16 | 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | 20 | from timm.models.layers import trunc_normal_ 21 | from timm.data.mixup import Mixup 22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 23 | from timm.utils import ModelEma 24 | from optim_factory import create_optimizer, LayerDecayValueAssigner 25 | 26 | from datasets import build_dataset 27 | from engine_finetune import train_one_epoch, evaluate 28 | 29 | import utils 30 | from utils import NativeScalerWithGradNormCount as NativeScaler 31 | from utils import str2bool, remap_checkpoint_keys 32 | import models.convnextv2 as convnextv2 33 | 34 | def get_args_parser(): 35 | parser = argparse.ArgumentParser('FCMAE fine-tuning', add_help=False) 36 | parser.add_argument('--batch_size', default=64, type=int, 37 | help='Per GPU batch size') 38 | parser.add_argument('--epochs', default=100, type=int) 39 | parser.add_argument('--update_freq', default=1, type=int, 40 | help='gradient accumulation steps') 41 | 42 | # Model parameters 43 | parser.add_argument('--model', default='convnextv2_base', type=str, metavar='MODEL', 44 | help='Name of model to train') 45 | parser.add_argument('--input_size', default=224, type=int, 46 | help='image input size') 47 | parser.add_argument('--drop_path', type=float, default=0., metavar='PCT', 48 | help='Drop path rate (default: 0.1)') 49 | parser.add_argument('--layer_decay_type', type=str, choices=['single', 'group'], default='single', 50 | help="""Layer decay strategies. The single strategy assigns a distinct decaying value for each layer, 51 | whereas the group strategy assigns the same decaying value for three consecutive layers""") 52 | 53 | # EMA related parameters 54 | parser.add_argument('--model_ema', type=str2bool, default=False) 55 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 56 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 57 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 58 | 59 | # Optimization parameters 60 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 61 | help='Clip gradient norm (default: None, no clipping)') 62 | parser.add_argument('--weight_decay', type=float, default=0.05, 63 | help='weight decay (default: 0.05)') 64 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 65 | help='learning rate (absolute lr)') 66 | parser.add_argument('--blr', type=float, default=5e-4, metavar='LR', 67 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 68 | parser.add_argument('--layer_decay', type=float, default=1.0) 69 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 70 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 71 | parser.add_argument('--warmup_epochs', type=int, default=20, metavar='N', 72 | help='epochs to warmup LR, if scheduler supports') 73 | 74 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 75 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 76 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 77 | help='Optimizer (default: "adamw"') 78 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 79 | help='Optimizer Epsilon (default: 1e-8)') 
80 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 81 | help='Optimizer Betas (default: None, use opt default)') 82 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 83 | help='SGD momentum (default: 0.9)') 84 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 85 | weight decay. We use a cosine schedule for WD and using a larger decay by 86 | the end of training improves performance for ViTs.""") 87 | 88 | # Augmentation parameters 89 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 90 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 91 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 92 | help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)') 93 | parser.add_argument('--smoothing', type=float, default=0.1, 94 | help='Label smoothing (default: 0.1)') 95 | 96 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 97 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 98 | 99 | # * Random Erase params 100 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 101 | help='Random erase prob (default: 0.25)') 102 | parser.add_argument('--remode', type=str, default='pixel', 103 | help='Random erase mode (default: "pixel")') 104 | parser.add_argument('--recount', type=int, default=1, 105 | help='Random erase count (default: 1)') 106 | parser.add_argument('--resplit', type=str2bool, default=False, 107 | help='Do not random erase first (clean) augmentation split') 108 | 109 | # * Mixup params 110 | parser.add_argument('--mixup', type=float, default=0., 111 | help='mixup alpha, mixup enabled if > 0.') 112 | parser.add_argument('--cutmix', type=float, default=0., 113 | help='cutmix alpha, cutmix enabled if > 0.') 114 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 115 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 116 | parser.add_argument('--mixup_prob', type=float, default=1.0, 117 | help='Probability of performing mixup or cutmix when either/both is enabled') 118 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 119 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 120 | parser.add_argument('--mixup_mode', type=str, default='batch', 121 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 122 | 123 | # * Finetuning params 124 | parser.add_argument('--finetune', default='', 125 | help='finetune from checkpoint') 126 | parser.add_argument('--head_init_scale', default=0.001, type=float, 127 | help='classifier head initial scale, typically adjusted in fine-tuning') 128 | parser.add_argument('--model_key', default='model|module', type=str, 129 | help='which key to load from saved state dict, usually model or model_ema') 130 | parser.add_argument('--model_prefix', default='', type=str) 131 | 132 | # Dataset parameters 133 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 134 | help='dataset path') 135 | parser.add_argument('--nb_classes', default=1000, type=int, 136 | help='number of the classification types') 137 | parser.add_argument('--output_dir', default='', 138 | help='path where to save, empty for no saving') 139 | parser.add_argument('--log_dir', default=None, 140 | help='path where to tensorboard log') 141 | parser.add_argument('--device', default='cuda', 142 | help='device to use for training / testing') 143 | parser.add_argument('--seed', default=0, type=int) 144 | parser.add_argument('--resume', default='', 145 | help='resume from checkpoint') 146 | 147 | parser.add_argument('--eval_data_path', default=None, type=str, 148 | help='dataset path for evaluation') 149 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 150 | parser.add_argument('--data_set', default='IMNET', choices=['CIFAR', 'IMNET', 'image_folder'], 151 | type=str, help='ImageNet dataset path') 152 | parser.add_argument('--auto_resume', type=str2bool, default=True) 153 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 154 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 155 | parser.add_argument('--save_ckpt_num', default=3, type=int) 156 | 157 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 158 | help='start epoch') 159 | parser.add_argument('--eval', type=str2bool, default=False, 160 | help='Perform evaluation only') 161 | parser.add_argument('--dist_eval', type=str2bool, default=True, 162 | help='Enabling distributed evaluation') 163 | parser.add_argument('--disable_eval', type=str2bool, default=False, 164 | help='Disabling evaluation during training') 165 | parser.add_argument('--num_workers', default=10, type=int) 166 | parser.add_argument('--pin_mem', type=str2bool, default=True, 167 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 168 | 169 | # Evaluation parameters 170 | parser.add_argument('--crop_pct', type=float, default=None) 171 | 172 | # distributed training parameters 173 | parser.add_argument('--world_size', default=1, type=int, 174 | help='number of distributed processes') 175 | parser.add_argument('--local_rank', default=-1, type=int) 176 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 177 | parser.add_argument('--dist_url', default='env://', 178 | help='url used to set up distributed training') 179 | 180 | parser.add_argument('--use_amp', type=str2bool, default=False, 181 | help="Use apex AMP (Automatic Mixed Precision) or not") 182 | return parser 183 | 184 | def main(args): 185 | utils.init_distributed_mode(args) 186 | print(args) 187 | device = torch.device(args.device) 188 | 189 | # fix the seed for reproducibility 190 | seed = args.seed + utils.get_rank() 191 | torch.manual_seed(seed) 192 | np.random.seed(seed) 193 | cudnn.benchmark = True 194 | 195 | 
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 196 | if args.disable_eval: 197 | args.dist_eval = False 198 | dataset_val = None 199 | else: 200 | dataset_val, _ = build_dataset(is_train=False, args=args) 201 | 202 | num_tasks = utils.get_world_size() 203 | global_rank = utils.get_rank() 204 | sampler_train = torch.utils.data.DistributedSampler( 205 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 206 | ) 207 | print("Sampler_train = %s" % str(sampler_train)) 208 | if args.dist_eval: 209 | if len(dataset_val) % num_tasks != 0: 210 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 211 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 212 | 'equal num of samples per-process.') 213 | sampler_val = torch.utils.data.DistributedSampler( 214 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 215 | else: 216 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 217 | 218 | if global_rank == 0 and args.log_dir is not None: 219 | os.makedirs(args.log_dir, exist_ok=True) 220 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 221 | else: 222 | log_writer = None 223 | 224 | data_loader_train = torch.utils.data.DataLoader( 225 | dataset_train, sampler=sampler_train, 226 | batch_size=args.batch_size, 227 | num_workers=args.num_workers, 228 | pin_memory=args.pin_mem, 229 | drop_last=True, 230 | ) 231 | if dataset_val is not None: 232 | data_loader_val = torch.utils.data.DataLoader( 233 | dataset_val, sampler=sampler_val, 234 | batch_size=args.batch_size, 235 | num_workers=args.num_workers, 236 | pin_memory=args.pin_mem, 237 | drop_last=False 238 | ) 239 | else: 240 | data_loader_val = None 241 | 242 | mixup_fn = None 243 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 244 | if mixup_active: 245 | print("Mixup is activated!") 246 | mixup_fn = Mixup( 247 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 248 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 249 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 250 | 251 | model = convnextv2.__dict__[args.model]( 252 | num_classes=args.nb_classes, 253 | drop_path_rate=args.drop_path, 254 | head_init_scale=args.head_init_scale, 255 | ) 256 | 257 | if args.finetune: 258 | checkpoint = torch.load(args.finetune, map_location='cpu') 259 | 260 | print("Load pre-trained checkpoint from: %s" % args.finetune) 261 | checkpoint_model = checkpoint['model'] 262 | state_dict = model.state_dict() 263 | for k in ['head.weight', 'head.bias']: 264 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 265 | print(f"Removing key {k} from pretrained checkpoint") 266 | del checkpoint_model[k] 267 | 268 | 269 | # remove decoder weights 270 | checkpoint_model_keys = list(checkpoint_model.keys()) 271 | for k in checkpoint_model_keys: 272 | if 'decoder' in k or 'mask_token'in k or \ 273 | 'proj' in k or 'pred' in k: 274 | print(f"Removing key {k} from pretrained checkpoint") 275 | del checkpoint_model[k] 276 | 277 | checkpoint_model = remap_checkpoint_keys(checkpoint_model) 278 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 279 | 280 | # manually initialize fc layer 281 | trunc_normal_(model.head.weight, std=2e-5) 282 | torch.nn.init.constant_(model.head.bias, 0.) 
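        # Added commentary (not in the original file): only encoder weights are
        # kept from the FCMAE checkpoint (decoder blocks, the mask token, the
        # encoder-to-decoder projection 'proj' and the pixel-prediction head
        # 'pred' were dropped above), and the new classification head is
        # re-initialized with a very small std so fine-tuning starts from
        # near-zero logits.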
283 | 284 | model.to(device) 285 | 286 | model_ema = None 287 | if args.model_ema: 288 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 289 | model_ema = ModelEma( 290 | model, 291 | decay=args.model_ema_decay, 292 | device='cpu' if args.model_ema_force_cpu else '', 293 | resume='') 294 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 295 | 296 | model_without_ddp = model 297 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 298 | 299 | print("Model = %s" % str(model_without_ddp)) 300 | print('number of params:', n_parameters) 301 | 302 | eff_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 303 | num_training_steps_per_epoch = len(dataset_train) // eff_batch_size 304 | 305 | if args.lr is None: 306 | args.lr = args.blr * eff_batch_size / 256 307 | 308 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 309 | print("actual lr: %.2e" % args.lr) 310 | 311 | print("accumulate grad iterations: %d" % args.update_freq) 312 | print("effective batch size: %d" % eff_batch_size) 313 | 314 | if args.layer_decay < 1.0 or args.layer_decay > 1.0: 315 | assert args.layer_decay_type in ['single', 'group'] 316 | if args.layer_decay_type == 'group': # applies for Base and Large models 317 | num_layers = 12 318 | else: 319 | num_layers = sum(model_without_ddp.depths) 320 | assigner = LayerDecayValueAssigner( 321 | list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)), 322 | depths=model_without_ddp.depths, layer_decay_type=args.layer_decay_type) 323 | else: 324 | assigner = None 325 | 326 | if assigner is not None: 327 | print("Assigned values = %s" % str(assigner.values)) 328 | 329 | if args.distributed: 330 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 331 | model_without_ddp = model.module 332 | 333 | optimizer = create_optimizer( 334 | args, model_without_ddp, skip_list=None, 335 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 336 | get_layer_scale=assigner.get_scale if assigner is not None else None) 337 | loss_scaler = NativeScaler() 338 | 339 | if mixup_fn is not None: 340 | # smoothing is handled with mixup label transform 341 | criterion = SoftTargetCrossEntropy() 342 | elif args.smoothing > 0.: 343 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 344 | else: 345 | criterion = torch.nn.CrossEntropyLoss() 346 | 347 | print("criterion = %s" % str(criterion)) 348 | 349 | utils.auto_load_model( 350 | args=args, model=model, model_without_ddp=model_without_ddp, 351 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 352 | 353 | if args.eval: 354 | print(f"Eval only mode") 355 | test_stats = evaluate(data_loader_val, model, device) 356 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 357 | return 358 | 359 | max_accuracy = 0.0 360 | if args.model_ema and args.model_ema_eval: 361 | max_accuracy_ema = 0.0 362 | 363 | print("Start training for %d epochs" % args.epochs) 364 | start_time = time.time() 365 | for epoch in range(args.start_epoch, args.epochs): 366 | if args.distributed: 367 | data_loader_train.sampler.set_epoch(epoch) 368 | if log_writer is not None: 369 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 370 | train_stats = train_one_epoch( 371 | model, criterion, data_loader_train, 372 | optimizer, device, epoch, loss_scaler, 373 | args.clip_grad, 
model_ema, mixup_fn, 374 | log_writer=log_writer, 375 | args=args 376 | ) 377 | if args.output_dir and args.save_ckpt: 378 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 379 | utils.save_model( 380 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 381 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 382 | if data_loader_val is not None: 383 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 384 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 385 | if max_accuracy < test_stats["acc1"]: 386 | max_accuracy = test_stats["acc1"] 387 | if args.output_dir and args.save_ckpt: 388 | utils.save_model( 389 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 390 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 391 | print(f'Max accuracy: {max_accuracy:.2f}%') 392 | 393 | if log_writer is not None: 394 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 395 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 396 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 397 | 398 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 399 | **{f'test_{k}': v for k, v in test_stats.items()}, 400 | 'epoch': epoch, 401 | 'n_parameters': n_parameters} 402 | 403 | # repeat testing routines for EMA, if ema eval is turned on 404 | if args.model_ema and args.model_ema_eval: 405 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 406 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 407 | if max_accuracy_ema < test_stats_ema["acc1"]: 408 | max_accuracy_ema = test_stats_ema["acc1"] 409 | if args.output_dir and args.save_ckpt: 410 | utils.save_model( 411 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 412 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 413 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 414 | if log_writer is not None: 415 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 416 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 417 | else: 418 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 419 | 'epoch': epoch, 420 | 'n_parameters': n_parameters} 421 | 422 | if args.output_dir and utils.is_main_process(): 423 | if log_writer is not None: 424 | log_writer.flush() 425 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 426 | f.write(json.dumps(log_stats) + "\n") 427 | 428 | total_time = time.time() - start_time 429 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 430 | print('Training time {}'.format(total_time_str)) 431 | 432 | if __name__ == '__main__': 433 | parser = argparse.ArgumentParser('FCMAE fine-tuning', parents=[get_args_parser()]) 434 | args = parser.parse_args() 435 | if args.output_dir: 436 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 437 | main(args) -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 
4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import argparse 10 | import datetime 11 | import numpy as np 12 | import time 13 | import json 14 | import os 15 | from pathlib import Path 16 | 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | 22 | import timm 23 | assert timm.__version__ == "0.3.2" # version check 24 | import timm.optim.optim_factory as optim_factory 25 | 26 | from engine_pretrain import train_one_epoch 27 | import models.fcmae as fcmae 28 | 29 | import utils 30 | from utils import NativeScalerWithGradNormCount as NativeScaler 31 | from utils import str2bool 32 | 33 | def get_args_parser(): 34 | parser = argparse.ArgumentParser('FCMAE pre-training', add_help=False) 35 | parser.add_argument('--batch_size', default=64, type=int, 36 | help='Per GPU batch size') 37 | parser.add_argument('--epochs', default=800, type=int) 38 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 39 | help='epochs to warmup LR') 40 | parser.add_argument('--update_freq', default=1, type=int, 41 | help='gradient accumulation step') 42 | 43 | # Model parameters 44 | parser.add_argument('--model', default='convnextv2_base', type=str, metavar='MODEL', 45 | help='Name of model to train') 46 | parser.add_argument('--input_size', default=224, type=int, 47 | help='image input size') 48 | parser.add_argument('--mask_ratio', default=0.6, type=float, 49 | help='Masking ratio (percentage of removed patches).') 50 | parser.add_argument('--norm_pix_loss', action='store_true', 51 | help='Use (per-patch) normalized pixels as targets for computing loss') 52 | parser.set_defaults(norm_pix_loss=True) 53 | parser.add_argument('--decoder_depth', type=int, default=1) 54 | parser.add_argument('--decoder_embed_dim', type=int, default=512) 55 | 56 | # Optimizer parameters 57 | parser.add_argument('--weight_decay', type=float, default=0.05, 58 | help='weight decay (default: 0.05)') 59 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 60 | help='learning rate (absolute lr)') 61 | parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', 62 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 63 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 64 | help='lower lr bound for cyclic schedulers that hit 0') 65 | 66 | # Dataset parameters 67 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 68 | help='dataset path') 69 | parser.add_argument('--output_dir', default='', 70 | help='path where to save, empty for no saving') 71 | parser.add_argument('--log_dir', default=None, 72 | help='path where to tensorboard log') 73 | parser.add_argument('--device', default='cuda', 74 | help='device to use for training / testing') 75 | parser.add_argument('--seed', default=0, type=int) 76 | parser.add_argument('--resume', default='', 77 | help='resume from checkpoint') 78 | 79 | parser.add_argument('--auto_resume', type=str2bool, default=True) 80 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 81 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 82 | parser.add_argument('--save_ckpt_num', default=3, type=int) 83 | 84 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 85 | help='start epoch') 86 | parser.add_argument('--num_workers', default=10, 
type=int) 87 | parser.add_argument('--pin_mem', type=str2bool, default=True, 88 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 89 | 90 | # Evaluation parameters 91 | parser.add_argument('--crop_pct', type=float, default=None) 92 | 93 | # distributed training parameters 94 | parser.add_argument('--world_size', default=1, type=int, 95 | help='number of distributed processes') 96 | parser.add_argument('--local_rank', default=-1, type=int) 97 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 98 | parser.add_argument('--dist_url', default='env://', 99 | help='url used to set up distributed training') 100 | return parser 101 | 102 | def main(args): 103 | utils.init_distributed_mode(args) 104 | 105 | print(args) 106 | device = torch.device(args.device) 107 | 108 | # fix the seed for reproducibility 109 | seed = args.seed + utils.get_rank() 110 | torch.manual_seed(seed) 111 | np.random.seed(seed) 112 | 113 | cudnn.benchmark = True 114 | 115 | # simple augmentation 116 | transform_train = transforms.Compose([ 117 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 118 | transforms.RandomHorizontalFlip(), 119 | transforms.ToTensor(), 120 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 121 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 122 | print(dataset_train) 123 | 124 | num_tasks = utils.get_world_size() 125 | global_rank = utils.get_rank() 126 | 127 | sampler_train = torch.utils.data.DistributedSampler( 128 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 129 | ) 130 | print("Sampler_train = %s" % str(sampler_train)) 131 | 132 | if global_rank == 0 and args.log_dir is not None: 133 | os.makedirs(args.log_dir, exist_ok=True) 134 | # log_writer = SummaryWriter(log_dir=args.log_dir) 135 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 136 | else: 137 | log_writer = None 138 | 139 | data_loader_train = torch.utils.data.DataLoader( 140 | dataset_train, sampler=sampler_train, 141 | batch_size=args.batch_size, 142 | num_workers=args.num_workers, 143 | pin_memory=args.pin_mem, 144 | drop_last=True, 145 | ) 146 | 147 | # define the model 148 | model = fcmae.__dict__[args.model]( 149 | mask_ratio=args.mask_ratio, 150 | decoder_depth=args.decoder_depth, 151 | decoder_embed_dim=args.decoder_embed_dim, 152 | norm_pix_loss=args.norm_pix_loss 153 | ) 154 | model.to(device) 155 | 156 | model_without_ddp = model 157 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 158 | 159 | print("Model = %s" % str(model_without_ddp)) 160 | print('number of params:', n_parameters) 161 | 162 | eff_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 163 | num_training_steps_per_epoch = len(dataset_train) // eff_batch_size 164 | 165 | if args.lr is None: 166 | args.lr = args.blr * eff_batch_size / 256 167 | 168 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 169 | print("actual lr: %.2e" % args.lr) 170 | 171 | print("accumulate grad iterations: %d" % args.update_freq) 172 | print("effective batch size: %d" % eff_batch_size) 173 | 174 | if args.distributed: 175 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 176 | model_without_ddp = model.module 177 | 178 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 179 | optimizer = 
torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 180 | print(optimizer) 181 | loss_scaler = NativeScaler() 182 | 183 | utils.auto_load_model( 184 | args=args, model=model, model_without_ddp=model_without_ddp, 185 | optimizer=optimizer, loss_scaler=loss_scaler) 186 | 187 | print(f"Start training for {args.epochs} epochs") 188 | start_time = time.time() 189 | for epoch in range(args.start_epoch, args.epochs): 190 | if args.distributed: 191 | data_loader_train.sampler.set_epoch(epoch) 192 | if log_writer is not None: 193 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 194 | train_stats = train_one_epoch( 195 | model, data_loader_train, 196 | optimizer, device, epoch, loss_scaler, 197 | log_writer=log_writer, 198 | args=args 199 | ) 200 | if args.output_dir and args.save_ckpt: 201 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 202 | utils.save_model( 203 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 204 | loss_scaler=loss_scaler, epoch=epoch) 205 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 206 | 'epoch': epoch, 207 | 'n_parameters': n_parameters} 208 | if args.output_dir and utils.is_main_process(): 209 | if log_writer is not None: 210 | log_writer.flush() 211 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 212 | f.write(json.dumps(log_stats) + "\n") 213 | 214 | total_time = time.time() - start_time 215 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 216 | print('Training time {}'.format(total_time_str)) 217 | 218 | if __name__ == '__main__': 219 | args = get_args_parser() 220 | args = args.parse_args() 221 | if args.output_dir: 222 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 223 | main(args) -------------------------------------------------------------------------------- /models/convnextv2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from timm.models.layers import trunc_normal_, DropPath 12 | from .utils import LayerNorm, GRN 13 | 14 | class Block(nn.Module): 15 | """ ConvNeXtV2 Block. 16 | 17 | Args: 18 | dim (int): Number of input channels. 19 | drop_path (float): Stochastic depth rate. Default: 0.0 20 | """ 21 | def __init__(self, dim, drop_path=0.): 22 | super().__init__() 23 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 24 | self.norm = LayerNorm(dim, eps=1e-6) 25 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 26 | self.act = nn.GELU() 27 | self.grn = GRN(4 * dim) 28 | self.pwconv2 = nn.Linear(4 * dim, dim) 29 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 30 | 31 | def forward(self, x): 32 | input = x 33 | x = self.dwconv(x) 34 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 35 | x = self.norm(x) 36 | x = self.pwconv1(x) 37 | x = self.act(x) 38 | x = self.grn(x) 39 | x = self.pwconv2(x) 40 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 41 | 42 | x = input + self.drop_path(x) 43 | return x 44 | 45 | class ConvNeXtV2(nn.Module): 46 | """ ConvNeXt V2 47 | 48 | Args: 49 | in_chans (int): Number of input image channels. Default: 3 50 | num_classes (int): Number of classes for classification head. Default: 1000 51 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 52 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 53 | drop_path_rate (float): Stochastic depth rate. Default: 0. 54 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 55 | """ 56 | def __init__(self, in_chans=3, num_classes=1000, 57 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 58 | drop_path_rate=0., head_init_scale=1. 59 | ): 60 | super().__init__() 61 | self.depths = depths 62 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 63 | stem = nn.Sequential( 64 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 65 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 66 | ) 67 | self.downsample_layers.append(stem) 68 | for i in range(3): 69 | downsample_layer = nn.Sequential( 70 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 71 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 72 | ) 73 | self.downsample_layers.append(downsample_layer) 74 | 75 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 76 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 77 | cur = 0 78 | for i in range(4): 79 | stage = nn.Sequential( 80 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])] 81 | ) 82 | self.stages.append(stage) 83 | cur += depths[i] 84 | 85 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 86 | self.head = nn.Linear(dims[-1], num_classes) 87 | 88 | self.apply(self._init_weights) 89 | self.head.weight.data.mul_(head_init_scale) 90 | self.head.bias.data.mul_(head_init_scale) 91 | 92 | def _init_weights(self, m): 93 | if isinstance(m, (nn.Conv2d, nn.Linear)): 94 | trunc_normal_(m.weight, std=.02) 95 | nn.init.constant_(m.bias, 0) 96 | 97 | def forward_features(self, x): 98 | for i in range(4): 99 | x = self.downsample_layers[i](x) 100 | x = self.stages[i](x) 101 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 102 | 103 | def forward(self, x): 104 | x = self.forward_features(x) 105 | x = self.head(x) 106 | return x 107 | 108 | def convnextv2_atto(**kwargs): 109 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 110 | return model 111 | 112 | def convnextv2_femto(**kwargs): 113 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 114 | return model 115 | 116 | def convnext_pico(**kwargs): 117 | model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 118 | return model 119 | 120 | def convnextv2_nano(**kwargs): 121 | model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) 122 | return model 123 | 124 | def convnextv2_tiny(**kwargs): 125 | model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 126 | return model 127 | 
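# Added commentary (not part of the original file): each factory in this file
# only fixes the stage depths and widths of ConvNeXtV2. An illustrative usage
# sketch (values assumed, not taken from the repository docs), using the tiny
# variant defined above:
#
#   model = convnextv2_tiny(num_classes=1000, drop_path_rate=0.1)
#   images = torch.randn(2, 3, 224, 224)    # (N, C, H, W)
#   logits = model(images)                  # shape (2, 1000)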
128 | def convnextv2_base(**kwargs): 129 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 130 | return model 131 | 132 | def convnextv2_large(**kwargs): 133 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 134 | return model 135 | 136 | def convnextv2_huge(**kwargs): 137 | model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) 138 | return model -------------------------------------------------------------------------------- /models/convnextv2_sparse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | from timm.models.layers import trunc_normal_ 12 | 13 | from .utils import ( 14 | LayerNorm, 15 | MinkowskiLayerNorm, 16 | MinkowskiGRN, 17 | MinkowskiDropPath 18 | ) 19 | from MinkowskiEngine import ( 20 | MinkowskiConvolution, 21 | MinkowskiDepthwiseConvolution, 22 | MinkowskiLinear, 23 | MinkowskiGELU 24 | ) 25 | from MinkowskiOps import ( 26 | to_sparse, 27 | ) 28 | 29 | class Block(nn.Module): 30 | """ Sparse ConvNeXtV2 Block. 31 | 32 | Args: 33 | dim (int): Number of input channels. 34 | drop_path (float): Stochastic depth rate. Default: 0.0 35 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 36 | """ 37 | def __init__(self, dim, drop_path=0., D=3): 38 | super().__init__() 39 | self.dwconv = MinkowskiDepthwiseConvolution(dim, kernel_size=7, bias=True, dimension=D) 40 | self.norm = MinkowskiLayerNorm(dim, 1e-6) 41 | self.pwconv1 = MinkowskiLinear(dim, 4 * dim) 42 | self.act = MinkowskiGELU() 43 | self.pwconv2 = MinkowskiLinear(4 * dim, dim) 44 | self.grn = MinkowskiGRN(4 * dim) 45 | self.drop_path = MinkowskiDropPath(drop_path) 46 | 47 | def forward(self, x): 48 | input = x 49 | x = self.dwconv(x) 50 | x = self.norm(x) 51 | x = self.pwconv1(x) 52 | x = self.act(x) 53 | x = self.grn(x) 54 | x = self.pwconv2(x) 55 | x = input + self.drop_path(x) 56 | return x 57 | 58 | class SparseConvNeXtV2(nn.Module): 59 | """ Sparse ConvNeXtV2. 60 | 61 | Args: 62 | in_chans (int): Number of input image channels. Default: 3 63 | num_classes (int): Number of classes for classification head. Default: 1000 64 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 65 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 66 | drop_path_rate (float): Stochastic depth rate. Default: 0. 67 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
68 | """ 69 | def __init__(self, 70 | in_chans=3, 71 | num_classes=1000, 72 | depths=[3, 3, 9, 3], 73 | dims=[96, 192, 384, 768], 74 | drop_path_rate=0., 75 | D=3): 76 | super().__init__() 77 | self.depths = depths 78 | self.num_classes = num_classes 79 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 80 | stem = nn.Sequential( 81 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 82 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 83 | ) 84 | self.downsample_layers.append(stem) 85 | for i in range(3): 86 | downsample_layer = nn.Sequential( 87 | MinkowskiLayerNorm(dims[i], eps=1e-6), 88 | MinkowskiConvolution(dims[i], dims[i+1], kernel_size=2, stride=2, bias=True, dimension=D) 89 | ) 90 | self.downsample_layers.append(downsample_layer) 91 | 92 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 93 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 94 | cur = 0 95 | for i in range(4): 96 | stage = nn.Sequential( 97 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], D=D) for j in range(depths[i])] 98 | ) 99 | self.stages.append(stage) 100 | cur += depths[i] 101 | 102 | self.apply(self._init_weights) 103 | 104 | def _init_weights(self, m): 105 | if isinstance(m, MinkowskiConvolution): 106 | trunc_normal_(m.kernel, std=.02) 107 | nn.init.constant_(m.bias, 0) 108 | if isinstance(m, MinkowskiDepthwiseConvolution): 109 | trunc_normal_(m.kernel, std=.02) 110 | nn.init.constant_(m.bias, 0) 111 | if isinstance(m, MinkowskiLinear): 112 | trunc_normal_(m.linear.weight, std=.02) 113 | nn.init.constant_(m.linear.bias, 0) 114 | 115 | def upsample_mask(self, mask, scale): 116 | assert len(mask.shape) == 2 117 | p = int(mask.shape[1] ** .5) 118 | return mask.reshape(-1, p, p).\ 119 | repeat_interleave(scale, axis=1).\ 120 | repeat_interleave(scale, axis=2) 121 | 122 | def forward(self, x, mask): 123 | num_stages = len(self.stages) 124 | mask = self.upsample_mask(mask, 2**(num_stages-1)) 125 | mask = mask.unsqueeze(1).type_as(x) 126 | 127 | # patch embedding 128 | x = self.downsample_layers[0](x) 129 | x *= (1.-mask) 130 | 131 | # sparse encoding 132 | x = to_sparse(x) 133 | for i in range(4): 134 | x = self.downsample_layers[i](x) if i > 0 else x 135 | x = self.stages[i](x) 136 | 137 | # densify 138 | x = x.dense()[0] 139 | return x -------------------------------------------------------------------------------- /models/fcmae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from MinkowskiEngine import ( 13 | MinkowskiConvolution, 14 | MinkowskiDepthwiseConvolution, 15 | MinkowskiLinear, 16 | ) 17 | 18 | from timm.models.layers import trunc_normal_ 19 | from .convnextv2_sparse import SparseConvNeXtV2 20 | from .convnextv2 import Block 21 | 22 | class FCMAE(nn.Module): 23 | """ Fully Convolutional Masked Autoencoder with ConvNeXtV2 backbone 24 | """ 25 | def __init__( 26 | self, 27 | img_size=224, 28 | in_chans=3, 29 | depths=[3, 3, 9, 3], 30 | dims=[96, 192, 384, 768], 31 | decoder_depth=1, 32 | decoder_embed_dim=512, 33 | patch_size=32, 34 | mask_ratio=0.6, 35 | norm_pix_loss=False): 36 | super().__init__() 37 | 38 | # configs 39 | self.img_size = img_size 40 | self.depths = depths 41 | self.imds = dims 42 | self.patch_size = patch_size 43 | self.mask_ratio = mask_ratio 44 | self.num_patches = (img_size // patch_size) ** 2 45 | self.decoder_embed_dim = decoder_embed_dim 46 | self.decoder_depth = decoder_depth 47 | self.norm_pix_loss = norm_pix_loss 48 | 49 | # encoder 50 | self.encoder = SparseConvNeXtV2( 51 | in_chans=in_chans, depths=depths, dims=dims, D=2) 52 | # decoder 53 | self.proj = nn.Conv2d( 54 | in_channels=dims[-1], 55 | out_channels=decoder_embed_dim, 56 | kernel_size=1) 57 | # mask tokens 58 | self.mask_token = nn.Parameter(torch.zeros(1, decoder_embed_dim, 1, 1)) 59 | decoder = [Block( 60 | dim=decoder_embed_dim, 61 | drop_path=0.) for i in range(decoder_depth)] 62 | self.decoder = nn.Sequential(*decoder) 63 | # pred 64 | self.pred = nn.Conv2d( 65 | in_channels=decoder_embed_dim, 66 | out_channels=patch_size ** 2 * in_chans, 67 | kernel_size=1) 68 | 69 | self.apply(self._init_weights) 70 | 71 | def _init_weights(self, m): 72 | if isinstance(m, MinkowskiConvolution): 73 | trunc_normal_(m.kernel, std=.02) 74 | nn.init.constant_(m.bias, 0) 75 | if isinstance(m, MinkowskiDepthwiseConvolution): 76 | trunc_normal_(m.kernel) 77 | nn.init.constant_(m.bias, 0) 78 | if isinstance(m, MinkowskiLinear): 79 | trunc_normal_(m.linear.weight) 80 | nn.init.constant_(m.linear.bias, 0) 81 | if isinstance(m, nn.Conv2d): 82 | w = m.weight.data 83 | trunc_normal_(w.view([w.shape[0], -1])) 84 | nn.init.constant_(m.bias, 0) 85 | if isinstance(m, nn.LayerNorm): 86 | nn.init.constant_(m.bias, 0) 87 | nn.init.constant_(m.weight, 1.0) 88 | if hasattr(self, 'mask_token'): 89 | torch.nn.init.normal_(self.mask_token, std=.02) 90 | 91 | def patchify(self, imgs): 92 | """ 93 | imgs: (N, 3, H, W) 94 | x: (N, L, patch_size**2 *3) 95 | """ 96 | p = self.patch_size 97 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 98 | 99 | h = w = imgs.shape[2] // p 100 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 101 | x = torch.einsum('nchpwq->nhwpqc', x) 102 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 103 | return x 104 | 105 | def unpatchify(self, x): 106 | """ 107 | x: (N, L, patch_size**2 *3) 108 | imgs: (N, 3, H, W) 109 | """ 110 | p = self.patch_size 111 | h = w = int(x.shape[1]**.5) 112 | assert h * w == x.shape[1] 113 | 114 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 115 | x = torch.einsum('nhwpqc->nchpwq', x) 116 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 117 | return imgs 118 | 119 | def gen_random_mask(self, x, mask_ratio): 120 | N = x.shape[0] 121 | L = (x.shape[2] // self.patch_size) ** 2 122 | len_keep = int(L * (1 - mask_ratio)) 123 | 124 | noise = torch.randn(N, L, device=x.device) 125 | 126 | # sort noise for each sample 127 | ids_shuffle = 
torch.argsort(noise, dim=1) 128 | ids_restore = torch.argsort(ids_shuffle, dim=1) 129 | 130 | # generate the binary mask: 0 is keep 1 is remove 131 | mask = torch.ones([N, L], device=x.device) 132 | mask[:, :len_keep] = 0 133 | # unshuffle to get the binary mask 134 | mask = torch.gather(mask, dim=1, index=ids_restore) 135 | return mask 136 | 137 | def upsample_mask(self, mask, scale): 138 | assert len(mask.shape) == 2 139 | p = int(mask.shape[1] ** .5) 140 | return mask.reshape(-1, p, p).\ 141 | repeat_interleave(scale, axis=1).\ 142 | repeat_interleave(scale, axis=2) 143 | 144 | def forward_encoder(self, imgs, mask_ratio): 145 | # generate random masks 146 | mask = self.gen_random_mask(imgs, mask_ratio) 147 | # encoding 148 | x = self.encoder(imgs, mask) 149 | return x, mask 150 | 151 | def forward_decoder(self, x, mask): 152 | x = self.proj(x) 153 | # append mask token 154 | n, c, h, w = x.shape 155 | mask = mask.reshape(-1, h, w).unsqueeze(1).type_as(x) 156 | mask_token = self.mask_token.repeat(x.shape[0], 1, x.shape[2], x.shape[3]) 157 | x = x * (1. - mask) + mask_token * mask 158 | # decoding 159 | x = self.decoder(x) 160 | # pred 161 | pred = self.pred(x) 162 | return pred 163 | 164 | def forward_loss(self, imgs, pred, mask): 165 | """ 166 | imgs: [N, 3, H, W] 167 | pred: [N, L, p*p*3] 168 | mask: [N, L], 0 is keep, 1 is remove 169 | """ 170 | if len(pred.shape) == 4: 171 | n, c, _, _ = pred.shape 172 | pred = pred.reshape(n, c, -1) 173 | pred = torch.einsum('ncl->nlc', pred) 174 | 175 | target = self.patchify(imgs) 176 | if self.norm_pix_loss: 177 | mean = target.mean(dim=-1, keepdim=True) 178 | var = target.var(dim=-1, keepdim=True) 179 | target = (target - mean) / (var + 1.e-6)**.5 180 | loss = (pred - target) ** 2 181 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 182 | 183 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 184 | return loss 185 | 186 | def forward(self, imgs, labels=None, mask_ratio=0.6): 187 | x, mask = self.forward_encoder(imgs, mask_ratio) 188 | pred = self.forward_decoder(x, mask) 189 | loss = self.forward_loss(imgs, pred, mask) 190 | return loss, pred, mask 191 | 192 | def convnextv2_atto(**kwargs): 193 | model = FCMAE( 194 | depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs) 195 | return model 196 | 197 | def convnextv2_femto(**kwargs): 198 | model = FCMAE( 199 | depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs) 200 | return model 201 | 202 | def convnextv2_pico(**kwargs): 203 | model = FCMAE( 204 | depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs) 205 | return model 206 | 207 | def convnextv2_nano(**kwargs): 208 | model = FCMAE( 209 | depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs) 210 | return model 211 | 212 | def convnextv2_tiny(**kwargs): 213 | model = FCMAE( 214 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 215 | return model 216 | 217 | def convnextv2_base(**kwargs): 218 | model = FCMAE( 219 | depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 220 | return model 221 | 222 | def convnextv2_large(**kwargs): 223 | model = FCMAE( 224 | depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 225 | return model 226 | 227 | def convnextv2_huge(**kwargs): 228 | model = FCMAE( 229 | depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs) 230 | return model -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, 
Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import numpy.random as random 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from MinkowskiEngine import SparseTensor 15 | 16 | class MinkowskiGRN(nn.Module): 17 | """ GRN layer for sparse tensors. 18 | """ 19 | def __init__(self, dim): 20 | super().__init__() 21 | self.gamma = nn.Parameter(torch.zeros(1, dim)) 22 | self.beta = nn.Parameter(torch.zeros(1, dim)) 23 | 24 | def forward(self, x): 25 | cm = x.coordinate_manager 26 | in_key = x.coordinate_map_key 27 | 28 | Gx = torch.norm(x.F, p=2, dim=0, keepdim=True) 29 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 30 | return SparseTensor( 31 | self.gamma * (x.F * Nx) + self.beta + x.F, 32 | coordinate_map_key=in_key, 33 | coordinate_manager=cm) 34 | 35 | class MinkowskiDropPath(nn.Module): 36 | """ Drop Path for sparse tensors. 37 | """ 38 | 39 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 40 | super(MinkowskiDropPath, self).__init__() 41 | self.drop_prob = drop_prob 42 | self.scale_by_keep = scale_by_keep 43 | 44 | def forward(self, x): 45 | if self.drop_prob == 0. or not self.training: 46 | return x 47 | cm = x.coordinate_manager 48 | in_key = x.coordinate_map_key 49 | keep_prob = 1 - self.drop_prob 50 | mask = torch.cat([ 51 | torch.ones(len(_)) if random.uniform(0, 1) > self.drop_prob 52 | else torch.zeros(len(_)) for _ in x.decomposed_coordinates 53 | ]).view(-1, 1).to(x.device) 54 | if keep_prob > 0.0 and self.scale_by_keep: 55 | mask.div_(keep_prob) 56 | return SparseTensor( 57 | x.F * mask, 58 | coordinate_map_key=in_key, 59 | coordinate_manager=cm) 60 | 61 | class MinkowskiLayerNorm(nn.Module): 62 | """ Channel-wise layer normalization for sparse tensors. 63 | """ 64 | 65 | def __init__( 66 | self, 67 | normalized_shape, 68 | eps=1e-6, 69 | ): 70 | super(MinkowskiLayerNorm, self).__init__() 71 | self.ln = nn.LayerNorm(normalized_shape, eps=eps) 72 | def forward(self, input): 73 | output = self.ln(input.F) 74 | return SparseTensor( 75 | output, 76 | coordinate_map_key=input.coordinate_map_key, 77 | coordinate_manager=input.coordinate_manager) 78 | 79 | class LayerNorm(nn.Module): 80 | """ LayerNorm that supports two data formats: channels_last (default) or channels_first. 81 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 82 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 83 | with shape (batch_size, channels, height, width). 
84 | """ 85 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 86 | super().__init__() 87 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 88 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 89 | self.eps = eps 90 | self.data_format = data_format 91 | if self.data_format not in ["channels_last", "channels_first"]: 92 | raise NotImplementedError 93 | self.normalized_shape = (normalized_shape, ) 94 | 95 | def forward(self, x): 96 | if self.data_format == "channels_last": 97 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 98 | elif self.data_format == "channels_first": 99 | u = x.mean(1, keepdim=True) 100 | s = (x - u).pow(2).mean(1, keepdim=True) 101 | x = (x - u) / torch.sqrt(s + self.eps) 102 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 103 | return x 104 | 105 | class GRN(nn.Module): 106 | """ GRN (Global Response Normalization) layer 107 | """ 108 | def __init__(self, dim): 109 | super().__init__() 110 | self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) 111 | self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) 112 | 113 | def forward(self, x): 114 | Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True) 115 | Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) 116 | return self.gamma * (x * Nx) + self.beta + x -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import torch 10 | from torch import optim as optim 11 | 12 | from timm.optim.adafactor import Adafactor 13 | from timm.optim.adahessian import Adahessian 14 | from timm.optim.adamp import AdamP 15 | from timm.optim.lookahead import Lookahead 16 | from timm.optim.nadam import Nadam 17 | from timm.optim.novograd import NovoGrad 18 | from timm.optim.nvnovograd import NvNovoGrad 19 | from timm.optim.radam import RAdam 20 | from timm.optim.rmsprop_tf import RMSpropTF 21 | from timm.optim.sgdp import SGDP 22 | 23 | import json 24 | 25 | try: 26 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 27 | has_apex = True 28 | except ImportError: 29 | has_apex = False 30 | 31 | 32 | def get_num_layer_for_convnext_single(var_name, depths): 33 | """ 34 | Each layer is assigned distinctive layer ids 35 | """ 36 | if var_name.startswith("downsample_layers"): 37 | stage_id = int(var_name.split('.')[1]) 38 | layer_id = sum(depths[:stage_id]) + 1 39 | return layer_id 40 | 41 | elif var_name.startswith("stages"): 42 | stage_id = int(var_name.split('.')[1]) 43 | block_id = int(var_name.split('.')[2]) 44 | layer_id = sum(depths[:stage_id]) + block_id + 1 45 | return layer_id 46 | 47 | else: 48 | return sum(depths) + 1 49 | 50 | 51 | def get_num_layer_for_convnext(var_name): 52 | """ 53 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 54 | consecutive blocks, including possible neighboring downsample layers; 55 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 56 | """ 57 | num_max_layer = 12 58 | if var_name.startswith("downsample_layers"): 59 | stage_id = int(var_name.split('.')[1]) 60 | if stage_id == 0: 61 | layer_id = 0 62 | elif stage_id == 1 or stage_id == 2: 63 | layer_id = stage_id + 1 64 | elif stage_id == 3: 65 | layer_id 
= 12 66 | return layer_id 67 | 68 | elif var_name.startswith("stages"): 69 | stage_id = int(var_name.split('.')[1]) 70 | block_id = int(var_name.split('.')[2]) 71 | if stage_id == 0 or stage_id == 1: 72 | layer_id = stage_id + 1 73 | elif stage_id == 2: 74 | layer_id = 3 + block_id // 3 75 | elif stage_id == 3: 76 | layer_id = 12 77 | return layer_id 78 | else: 79 | return num_max_layer + 1 80 | 81 | class LayerDecayValueAssigner(object): 82 | def __init__(self, values, depths=[3,3,27,3], layer_decay_type='single'): 83 | self.values = values 84 | self.depths = depths 85 | self.layer_decay_type = layer_decay_type 86 | 87 | def get_scale(self, layer_id): 88 | return self.values[layer_id] 89 | 90 | def get_layer_id(self, var_name): 91 | if self.layer_decay_type == 'single': 92 | return get_num_layer_for_convnext_single(var_name, self.depths) 93 | else: 94 | return get_num_layer_for_convnext(var_name) 95 | 96 | 97 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 98 | parameter_group_names = {} 99 | parameter_group_vars = {} 100 | 101 | for name, param in model.named_parameters(): 102 | if not param.requires_grad: 103 | continue # frozen weights 104 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or \ 105 | name.endswith(".gamma") or name.endswith(".beta"): 106 | group_name = "no_decay" 107 | this_weight_decay = 0. 108 | else: 109 | group_name = "decay" 110 | this_weight_decay = weight_decay 111 | if get_num_layer is not None: 112 | layer_id = get_num_layer(name) 113 | group_name = "layer_%d_%s" % (layer_id, group_name) 114 | else: 115 | layer_id = None 116 | 117 | if group_name not in parameter_group_names: 118 | if get_layer_scale is not None: 119 | scale = get_layer_scale(layer_id) 120 | else: 121 | scale = 1. 122 | 123 | parameter_group_names[group_name] = { 124 | "weight_decay": this_weight_decay, 125 | "params": [], 126 | "lr_scale": scale 127 | } 128 | parameter_group_vars[group_name] = { 129 | "weight_decay": this_weight_decay, 130 | "params": [], 131 | "lr_scale": scale 132 | } 133 | 134 | parameter_group_vars[group_name]["params"].append(param) 135 | parameter_group_names[group_name]["params"].append(name) 136 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 137 | return list(parameter_group_vars.values()) 138 | 139 | 140 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 141 | opt_lower = args.opt.lower() 142 | weight_decay = args.weight_decay 143 | # if weight_decay and filter_bias_and_bn: 144 | if filter_bias_and_bn: 145 | skip = {} 146 | if skip_list is not None: 147 | skip = skip_list 148 | elif hasattr(model, 'no_weight_decay'): 149 | skip = model.no_weight_decay() 150 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 151 | weight_decay = 0. 
152 | else: 153 | parameters = model.parameters() 154 | 155 | if 'fused' in opt_lower: 156 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 157 | 158 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 159 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 160 | opt_args['eps'] = args.opt_eps 161 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 162 | opt_args['betas'] = args.opt_betas 163 | 164 | opt_split = opt_lower.split('_') 165 | opt_lower = opt_split[-1] 166 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 167 | opt_args.pop('eps', None) 168 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 169 | elif opt_lower == 'momentum': 170 | opt_args.pop('eps', None) 171 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 172 | elif opt_lower == 'adam': 173 | optimizer = optim.Adam(parameters, **opt_args) 174 | elif opt_lower == 'adamw': 175 | optimizer = optim.AdamW(parameters, **opt_args) 176 | elif opt_lower == 'nadam': 177 | optimizer = Nadam(parameters, **opt_args) 178 | elif opt_lower == 'radam': 179 | optimizer = RAdam(parameters, **opt_args) 180 | elif opt_lower == 'adamp': 181 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 182 | elif opt_lower == 'sgdp': 183 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 184 | elif opt_lower == 'adadelta': 185 | optimizer = optim.Adadelta(parameters, **opt_args) 186 | elif opt_lower == 'adafactor': 187 | if not args.lr: 188 | opt_args['lr'] = None 189 | optimizer = Adafactor(parameters, **opt_args) 190 | elif opt_lower == 'adahessian': 191 | optimizer = Adahessian(parameters, **opt_args) 192 | elif opt_lower == 'rmsprop': 193 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 194 | elif opt_lower == 'rmsproptf': 195 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 196 | elif opt_lower == 'novograd': 197 | optimizer = NovoGrad(parameters, **opt_args) 198 | elif opt_lower == 'nvnovograd': 199 | optimizer = NvNovoGrad(parameters, **opt_args) 200 | elif opt_lower == 'fusedsgd': 201 | opt_args.pop('eps', None) 202 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 203 | elif opt_lower == 'fusedmomentum': 204 | opt_args.pop('eps', None) 205 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 206 | elif opt_lower == 'fusedadam': 207 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 208 | elif opt_lower == 'fusedadamw': 209 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 210 | elif opt_lower == 'fusedlamb': 211 | optimizer = FusedLAMB(parameters, **opt_args) 212 | elif opt_lower == 'fusednovograd': 213 | opt_args.setdefault('betas', (0.95, 0.98)) 214 | optimizer = FusedNovoGrad(parameters, **opt_args) 215 | else: 216 | assert False and "Invalid optimizer" 217 | 218 | if len(opt_split) > 1: 219 | if opt_split[0] == 'lookahead': 220 | optimizer = Lookahead(optimizer) 221 | 222 | return optimizer 223 | -------------------------------------------------------------------------------- /submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 
4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import argparse 10 | import os 11 | import uuid 12 | from pathlib import Path 13 | 14 | import main_finetune as trainer 15 | import submitit 16 | 17 | def parse_args(): 18 | trainer_parser = trainer.get_args_parser() 19 | parser = argparse.ArgumentParser("Submitit for finetune", parents=[trainer_parser]) 20 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 21 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 22 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 23 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 24 | 25 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") 26 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 27 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 28 | return parser.parse_args() 29 | 30 | def get_shared_folder() -> Path: 31 | user = os.getenv("USER") 32 | if Path("/checkpoint/").is_dir(): 33 | p = Path(f"/checkpoint/{user}/experiments") 34 | p.mkdir(exist_ok=True) 35 | return p 36 | raise RuntimeError("No shared folder available") 37 | 38 | def get_init_file(): 39 | # Init file must not exist, but it's parent dir must exist. 40 | os.makedirs(str(get_shared_folder()), exist_ok=True) 41 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 42 | if init_file.exists(): 43 | os.remove(str(init_file)) 44 | return init_file 45 | 46 | class Trainer(object): 47 | def __init__(self, args): 48 | self.args = args 49 | 50 | def __call__(self): 51 | import main_finetune as trainer 52 | 53 | self._setup_gpu_args() 54 | trainer.main(self.args) 55 | 56 | def checkpoint(self): 57 | import os 58 | import submitit 59 | 60 | self.args.dist_url = get_init_file().as_uri() 61 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 62 | if os.path.exists(checkpoint_file): 63 | self.args.resume = checkpoint_file 64 | print("Requeuing ", self.args) 65 | empty_trainer = type(self)(self.args) 66 | return submitit.helpers.DelayedSubmission(empty_trainer) 67 | 68 | def _setup_gpu_args(self): 69 | import submitit 70 | from pathlib import Path 71 | 72 | job_env = submitit.JobEnvironment() 73 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 74 | self.args.log_dir = self.args.output_dir 75 | self.args.gpu = job_env.local_rank 76 | self.args.rank = job_env.global_rank 77 | self.args.world_size = job_env.num_tasks 78 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 79 | 80 | def main(): 81 | args = parse_args() 82 | if args.job_dir == "": 83 | args.job_dir = get_shared_folder() / "%j" 84 | 85 | # Note that the folder will depend on the job_id, to easily track experiments 86 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 87 | 88 | num_gpus_per_node = args.ngpus 89 | nodes = args.nodes 90 | timeout_min = args.timeout 91 | 92 | partition = args.partition 93 | kwargs = {} 94 | if args.use_volta32: 95 | kwargs['slurm_constraint'] = 'volta32gb' 96 | if args.comment: 97 | kwargs['slurm_comment'] = args.comment 98 | 99 | executor.update_parameters( 100 | mem_gb=40 * num_gpus_per_node, 101 | gpus_per_node=num_gpus_per_node, 
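        # (descriptive comment) The Slurm request scales with --ngpus: 40 GB of host memory
        # per GPU above, and one task per GPU just below, so each rank drives a single device.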
102 | tasks_per_node=num_gpus_per_node, # one task per GPU 103 | cpus_per_task=10, 104 | nodes=nodes, 105 | timeout_min=timeout_min, # max is 60 * 72 106 | # Below are cluster dependent parameters 107 | slurm_partition=partition, 108 | slurm_signal_delay_s=120, 109 | **kwargs 110 | ) 111 | 112 | executor.update_parameters(name="finetune") 113 | 114 | args.dist_url = get_init_file().as_uri() 115 | args.output_dir = args.job_dir 116 | 117 | trainer = Trainer(args) 118 | job = executor.submit(trainer) 119 | 120 | # print("Submitted job_id:", job.job_id) 121 | print(job.job_id) 122 | 123 | if __name__ == "__main__": 124 | main() 125 | 126 | -------------------------------------------------------------------------------- /submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import argparse 10 | import os 11 | import uuid 12 | from pathlib import Path 13 | 14 | import main_pretrain as trainer 15 | import submitit 16 | 17 | def parse_args(): 18 | trainer_parser = trainer.get_args_parser() 19 | parser = argparse.ArgumentParser("Submitit for pretrain", parents=[trainer_parser]) 20 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 21 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 22 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 23 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 24 | 25 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") 26 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 27 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 28 | return parser.parse_args() 29 | 30 | def get_shared_folder() -> Path: 31 | user = os.getenv("USER") 32 | if Path("/checkpoint/").is_dir(): 33 | p = Path(f"/checkpoint/{user}/experiments") 34 | p.mkdir(exist_ok=True) 35 | return p 36 | raise RuntimeError("No shared folder available") 37 | 38 | def get_init_file(): 39 | # Init file must not exist, but it's parent dir must exist. 
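    # (descriptive comment) The returned path is later converted with .as_uri() and used as
    # args.dist_url, i.e. the file:// rendezvous that torch.distributed.init_process_group()
    # reads in utils.init_distributed_mode() when the job starts or is requeued.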
40 | os.makedirs(str(get_shared_folder()), exist_ok=True) 41 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 42 | if init_file.exists(): 43 | os.remove(str(init_file)) 44 | return init_file 45 | 46 | class Trainer(object): 47 | def __init__(self, args): 48 | self.args = args 49 | 50 | def __call__(self): 51 | import main_pretrain as trainer 52 | 53 | self._setup_gpu_args() 54 | trainer.main(self.args) 55 | 56 | def checkpoint(self): 57 | import os 58 | import submitit 59 | 60 | self.args.dist_url = get_init_file().as_uri() 61 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 62 | if os.path.exists(checkpoint_file): 63 | self.args.resume = checkpoint_file 64 | print("Requeuing ", self.args) 65 | empty_trainer = type(self)(self.args) 66 | return submitit.helpers.DelayedSubmission(empty_trainer) 67 | 68 | def _setup_gpu_args(self): 69 | import submitit 70 | from pathlib import Path 71 | 72 | job_env = submitit.JobEnvironment() 73 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 74 | self.args.log_dir = self.args.output_dir 75 | self.args.gpu = job_env.local_rank 76 | self.args.rank = job_env.global_rank 77 | self.args.world_size = job_env.num_tasks 78 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 79 | 80 | def main(): 81 | args = parse_args() 82 | if args.job_dir == "": 83 | args.job_dir = get_shared_folder() / "%j" 84 | 85 | # Note that the folder will depend on the job_id, to easily track experiments 86 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 87 | 88 | num_gpus_per_node = args.ngpus 89 | nodes = args.nodes 90 | timeout_min = args.timeout 91 | 92 | partition = args.partition 93 | kwargs = {} 94 | if args.use_volta32: 95 | kwargs['slurm_constraint'] = 'volta32gb' 96 | if args.comment: 97 | kwargs['slurm_comment'] = args.comment 98 | 99 | executor.update_parameters( 100 | mem_gb=40 * num_gpus_per_node, 101 | gpus_per_node=num_gpus_per_node, 102 | tasks_per_node=num_gpus_per_node, # one task per GPU 103 | cpus_per_task=10, 104 | nodes=nodes, 105 | timeout_min=timeout_min, # max is 60 * 72 106 | # Below are cluster dependent parameters 107 | slurm_partition=partition, 108 | slurm_signal_delay_s=120, 109 | **kwargs 110 | ) 111 | 112 | executor.update_parameters(name="pretrain") 113 | 114 | args.dist_url = get_init_file().as_uri() 115 | args.output_dir = args.job_dir 116 | 117 | trainer = Trainer(args) 118 | job = executor.submit(trainer) 119 | 120 | # print("Submitted job_id:", job.job_id) 121 | print(job.job_id) 122 | 123 | if __name__ == "__main__": 124 | main() 125 | 126 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
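# Shared helpers for the pre-training and fine-tuning entry points: distributed setup,
# console/TensorBoard/W&B metric logging, an AMP loss scaler with gradient-norm tracking,
# checkpoint saving and auto-resume, and cosine learning-rate schedules.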
7 | 8 | 9 | import os 10 | import math 11 | import time 12 | from collections import defaultdict, deque 13 | import datetime 14 | import numpy as np 15 | from timm.utils import get_state_dict 16 | 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | from tensorboardX import SummaryWriter 24 | from collections import OrderedDict 25 | 26 | def str2bool(v): 27 | """ 28 | Converts string to bool type; enables command line 29 | arguments in the format of '--arg1 true --arg2 false' 30 | """ 31 | if isinstance(v, bool): 32 | return v 33 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 34 | return True 35 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 36 | return False 37 | else: 38 | raise argparse.ArgumentTypeError('Boolean value expected.') 39 | 40 | class SmoothedValue(object): 41 | """Track a series of values and provide access to smoothed values over a 42 | window or the global series average. 43 | """ 44 | 45 | def __init__(self, window_size=20, fmt=None): 46 | if fmt is None: 47 | fmt = "{median:.4f} ({global_avg:.4f})" 48 | self.deque = deque(maxlen=window_size) 49 | self.total = 0.0 50 | self.count = 0 51 | self.fmt = fmt 52 | 53 | def update(self, value, n=1): 54 | self.deque.append(value) 55 | self.count += n 56 | self.total += value * n 57 | 58 | def synchronize_between_processes(self): 59 | """ 60 | Warning: does not synchronize the deque! 61 | """ 62 | if not is_dist_avail_and_initialized(): 63 | return 64 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 65 | dist.barrier() 66 | dist.all_reduce(t) 67 | t = t.tolist() 68 | self.count = int(t[0]) 69 | self.total = t[1] 70 | 71 | @property 72 | def median(self): 73 | d = torch.tensor(list(self.deque)) 74 | return d.median().item() 75 | 76 | @property 77 | def avg(self): 78 | d = torch.tensor(list(self.deque), dtype=torch.float32) 79 | return d.mean().item() 80 | 81 | @property 82 | def global_avg(self): 83 | return self.total / self.count 84 | 85 | @property 86 | def max(self): 87 | return max(self.deque) 88 | 89 | @property 90 | def value(self): 91 | return self.deque[-1] 92 | 93 | def __str__(self): 94 | return self.fmt.format( 95 | median=self.median, 96 | avg=self.avg, 97 | global_avg=self.global_avg, 98 | max=self.max, 99 | value=self.value) 100 | 101 | 102 | class MetricLogger(object): 103 | def __init__(self, delimiter="\t"): 104 | self.meters = defaultdict(SmoothedValue) 105 | self.delimiter = delimiter 106 | 107 | def update(self, **kwargs): 108 | for k, v in kwargs.items(): 109 | if v is None: 110 | continue 111 | if isinstance(v, torch.Tensor): 112 | v = v.item() 113 | assert isinstance(v, (float, int)) 114 | self.meters[k].update(v) 115 | 116 | def __getattr__(self, attr): 117 | if attr in self.meters: 118 | return self.meters[attr] 119 | if attr in self.__dict__: 120 | return self.__dict__[attr] 121 | raise AttributeError("'{}' object has no attribute '{}'".format( 122 | type(self).__name__, attr)) 123 | 124 | def __str__(self): 125 | loss_str = [] 126 | for name, meter in self.meters.items(): 127 | loss_str.append( 128 | "{}: {}".format(name, str(meter)) 129 | ) 130 | return self.delimiter.join(loss_str) 131 | 132 | def synchronize_between_processes(self): 133 | for meter in self.meters.values(): 134 | meter.synchronize_between_processes() 135 | 136 | def add_meter(self, name, meter): 137 | self.meters[name] = meter 138 | 139 | def log_every(self, iterable, print_freq, header=None): 140 | i = 0 141 | if 
not header: 142 | header = '' 143 | start_time = time.time() 144 | end = time.time() 145 | iter_time = SmoothedValue(fmt='{avg:.4f}') 146 | data_time = SmoothedValue(fmt='{avg:.4f}') 147 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 148 | log_msg = [ 149 | header, 150 | '[{0' + space_fmt + '}/{1}]', 151 | 'eta: {eta}', 152 | '{meters}', 153 | 'time: {time}', 154 | 'data: {data}' 155 | ] 156 | if torch.cuda.is_available(): 157 | log_msg.append('max mem: {memory:.0f}') 158 | log_msg = self.delimiter.join(log_msg) 159 | MB = 1024.0 * 1024.0 160 | for obj in iterable: 161 | data_time.update(time.time() - end) 162 | yield obj 163 | iter_time.update(time.time() - end) 164 | if i % print_freq == 0 or i == len(iterable) - 1: 165 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 166 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 167 | if torch.cuda.is_available(): 168 | print(log_msg.format( 169 | i, len(iterable), eta=eta_string, 170 | meters=str(self), 171 | time=str(iter_time), data=str(data_time), 172 | memory=torch.cuda.max_memory_allocated() / MB)) 173 | else: 174 | print(log_msg.format( 175 | i, len(iterable), eta=eta_string, 176 | meters=str(self), 177 | time=str(iter_time), data=str(data_time))) 178 | i += 1 179 | end = time.time() 180 | total_time = time.time() - start_time 181 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 182 | print('{} Total time: {} ({:.4f} s / it)'.format( 183 | header, total_time_str, total_time / len(iterable))) 184 | 185 | 186 | class TensorboardLogger(object): 187 | def __init__(self, log_dir): 188 | self.writer = SummaryWriter(logdir=log_dir) 189 | self.step = 0 190 | 191 | def set_step(self, step=None): 192 | if step is not None: 193 | self.step = step 194 | else: 195 | self.step += 1 196 | 197 | def update(self, head='scalar', step=None, **kwargs): 198 | for k, v in kwargs.items(): 199 | if v is None: 200 | continue 201 | if isinstance(v, torch.Tensor): 202 | v = v.item() 203 | assert isinstance(v, (float, int)) 204 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 205 | 206 | def flush(self): 207 | self.writer.flush() 208 | 209 | 210 | class WandbLogger(object): 211 | def __init__(self, args): 212 | self.args = args 213 | 214 | try: 215 | import wandb 216 | self._wandb = wandb 217 | except ImportError: 218 | raise ImportError( 219 | "To use the Weights and Biases Logger please install wandb." 220 | "Run `pip install wandb` to install it." 221 | ) 222 | 223 | # Initialize a W&B run 224 | if self._wandb.run is None: 225 | self._wandb.init( 226 | project=args.project, 227 | config=args 228 | ) 229 | 230 | def log_epoch_metrics(self, metrics, commit=True): 231 | """ 232 | Log train/test metrics onto W&B. 
233 | """ 234 | # Log number of model parameters as W&B summary 235 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 236 | metrics.pop('n_parameters', None) 237 | 238 | # Log current epoch 239 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 240 | metrics.pop('epoch') 241 | 242 | for k, v in metrics.items(): 243 | if 'train' in k: 244 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 245 | elif 'test' in k: 246 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 247 | 248 | self._wandb.log({}) 249 | 250 | def log_checkpoints(self): 251 | output_dir = self.args.output_dir 252 | model_artifact = self._wandb.Artifact( 253 | self._wandb.run.id + "_model", type="model" 254 | ) 255 | 256 | model_artifact.add_dir(output_dir) 257 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 258 | 259 | def set_steps(self): 260 | # Set global training step 261 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 262 | # Set epoch-wise step 263 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 264 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 265 | 266 | 267 | def setup_for_distributed(is_master): 268 | """ 269 | This function disables printing when not in master process 270 | """ 271 | import builtins as __builtin__ 272 | builtin_print = __builtin__.print 273 | 274 | def print(*args, **kwargs): 275 | force = kwargs.pop('force', False) 276 | if is_master or force: 277 | builtin_print(*args, **kwargs) 278 | 279 | __builtin__.print = print 280 | 281 | 282 | def is_dist_avail_and_initialized(): 283 | if not dist.is_available(): 284 | return False 285 | if not dist.is_initialized(): 286 | return False 287 | return True 288 | 289 | 290 | def get_world_size(): 291 | if not is_dist_avail_and_initialized(): 292 | return 1 293 | return dist.get_world_size() 294 | 295 | 296 | def get_rank(): 297 | if not is_dist_avail_and_initialized(): 298 | return 0 299 | return dist.get_rank() 300 | 301 | 302 | def is_main_process(): 303 | return get_rank() == 0 304 | 305 | 306 | def save_on_master(*args, **kwargs): 307 | if is_main_process(): 308 | torch.save(*args, **kwargs) 309 | 310 | def init_distributed_mode(args): 311 | 312 | if args.dist_on_itp: 313 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 314 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 315 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 316 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 317 | os.environ['LOCAL_RANK'] = str(args.gpu) 318 | os.environ['RANK'] = str(args.rank) 319 | os.environ['WORLD_SIZE'] = str(args.world_size) 320 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 321 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 322 | args.rank = int(os.environ["RANK"]) 323 | args.world_size = int(os.environ['WORLD_SIZE']) 324 | args.gpu = int(os.environ['LOCAL_RANK']) 325 | elif 'SLURM_PROCID' in os.environ: 326 | args.rank = int(os.environ['SLURM_PROCID']) 327 | args.gpu = args.rank % torch.cuda.device_count() 328 | 329 | os.environ['RANK'] = str(args.rank) 330 | os.environ['LOCAL_RANK'] = str(args.gpu) 331 | os.environ['WORLD_SIZE'] = str(args.world_size) 332 | else: 333 | print('Not using distributed mode') 334 | args.distributed = False 335 | return 336 | 337 | args.distributed = True 338 | 339 | torch.cuda.set_device(args.gpu) 340 | args.dist_backend = 'nccl' 341 | print('| distributed init (rank 
{}): {}, gpu {}'.format( 342 | args.rank, args.dist_url, args.gpu), flush=True) 343 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 344 | world_size=args.world_size, rank=args.rank) 345 | torch.distributed.barrier() 346 | setup_for_distributed(args.rank == 0) 347 | 348 | def all_reduce_mean(x): 349 | world_size = get_world_size() 350 | if world_size > 1: 351 | x_reduce = torch.tensor(x).cuda() 352 | dist.all_reduce(x_reduce) 353 | x_reduce /= world_size 354 | return x_reduce.item() 355 | else: 356 | return x 357 | 358 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 359 | missing_keys = [] 360 | unexpected_keys = [] 361 | error_msgs = [] 362 | # copy state_dict so _load_from_state_dict can modify it 363 | metadata = getattr(state_dict, '_metadata', None) 364 | state_dict = state_dict.copy() 365 | if metadata is not None: 366 | state_dict._metadata = metadata 367 | 368 | def load(module, prefix=''): 369 | local_metadata = {} if metadata is None else metadata.get( 370 | prefix[:-1], {}) 371 | module._load_from_state_dict( 372 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 373 | for name, child in module._modules.items(): 374 | if child is not None: 375 | load(child, prefix + name + '.') 376 | 377 | load(model, prefix=prefix) 378 | 379 | warn_missing_keys = [] 380 | ignore_missing_keys = [] 381 | for key in missing_keys: 382 | keep_flag = True 383 | for ignore_key in ignore_missing.split('|'): 384 | if ignore_key in key: 385 | keep_flag = False 386 | break 387 | if keep_flag: 388 | warn_missing_keys.append(key) 389 | else: 390 | ignore_missing_keys.append(key) 391 | 392 | missing_keys = warn_missing_keys 393 | 394 | if len(missing_keys) > 0: 395 | print("Weights of {} not initialized from pretrained model: {}".format( 396 | model.__class__.__name__, missing_keys)) 397 | if len(unexpected_keys) > 0: 398 | print("Weights from pretrained model not used in {}: {}".format( 399 | model.__class__.__name__, unexpected_keys)) 400 | if len(ignore_missing_keys) > 0: 401 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 402 | model.__class__.__name__, ignore_missing_keys)) 403 | if len(error_msgs) > 0: 404 | print('\n'.join(error_msgs)) 405 | 406 | 407 | class NativeScalerWithGradNormCount: 408 | state_dict_key = "amp_scaler" 409 | 410 | def __init__(self): 411 | self._scaler = torch.cuda.amp.GradScaler() 412 | 413 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 414 | self._scaler.scale(loss).backward(create_graph=create_graph) 415 | if update_grad: 416 | if clip_grad is not None: 417 | assert parameters is not None 418 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 419 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 420 | else: 421 | self._scaler.unscale_(optimizer) 422 | norm = get_grad_norm_(parameters) 423 | self._scaler.step(optimizer) 424 | self._scaler.update() 425 | else: 426 | norm = None 427 | return norm 428 | 429 | def state_dict(self): 430 | return self._scaler.state_dict() 431 | 432 | def load_state_dict(self, state_dict): 433 | self._scaler.load_state_dict(state_dict) 434 | 435 | 436 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 437 | if isinstance(parameters, torch.Tensor): 438 | parameters = [parameters] 439 | parameters = [p for p in parameters if p.grad is not None] 440 | 
norm_type = float(norm_type) 441 | if len(parameters) == 0: 442 | return torch.tensor(0.) 443 | device = parameters[0].grad.device 444 | if norm_type == inf: 445 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 446 | else: 447 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 448 | return total_norm 449 | 450 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 451 | output_dir = Path(args.output_dir) 452 | epoch_name = str(epoch) 453 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 454 | for checkpoint_path in checkpoint_paths: 455 | to_save = { 456 | 'model': model_without_ddp.state_dict(), 457 | 'optimizer': optimizer.state_dict(), 458 | 'epoch': epoch, 459 | 'scaler': loss_scaler.state_dict(), 460 | 'args': args, 461 | } 462 | 463 | if model_ema is not None: 464 | to_save['model_ema'] = get_state_dict(model_ema) 465 | 466 | save_on_master(to_save, checkpoint_path) 467 | 468 | if is_main_process() and isinstance(epoch, int): 469 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 470 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del) 471 | if os.path.exists(old_ckpt): 472 | os.remove(old_ckpt) 473 | 474 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 475 | output_dir = Path(args.output_dir) 476 | if args.auto_resume and len(args.resume) == 0: 477 | import glob 478 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 479 | latest_ckpt = -1 480 | for ckpt in all_checkpoints: 481 | t = ckpt.split('-')[-1].split('.')[0] 482 | if t.isdigit(): 483 | latest_ckpt = max(int(t), latest_ckpt) 484 | if latest_ckpt >= 0: 485 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 486 | print("Auto resume checkpoint: %s" % args.resume) 487 | 488 | if args.resume: 489 | if args.resume.startswith('https'): 490 | checkpoint = torch.hub.load_state_dict_from_url( 491 | args.resume, map_location='cpu', check_hash=True) 492 | else: 493 | checkpoint = torch.load(args.resume, map_location='cpu') 494 | 495 | model_without_ddp.load_state_dict(checkpoint['model']) 496 | print("Resume checkpoint %s" % args.resume) 497 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 498 | optimizer.load_state_dict(checkpoint['optimizer']) 499 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 500 | args.start_epoch = checkpoint['epoch'] + 1 501 | else: 502 | assert args.eval, 'Does not support resuming with checkpoint-best' 503 | if hasattr(args, 'model_ema') and args.model_ema: 504 | if 'model_ema' in checkpoint.keys(): 505 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 506 | else: 507 | model_ema.ema.load_state_dict(checkpoint['model']) 508 | if 'scaler' in checkpoint: 509 | loss_scaler.load_state_dict(checkpoint['scaler']) 510 | print("With optim & sched!") 511 | 512 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 513 | start_warmup_value=0, warmup_steps=-1): 514 | warmup_schedule = np.array([]) 515 | warmup_iters = warmup_epochs * niter_per_ep 516 | if warmup_steps > 0: 517 | warmup_iters = warmup_steps 518 | print("Set warmup steps = %d" % warmup_iters) 519 | if warmup_epochs > 0: 520 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 521 | 522 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 523 | schedule = 
np.array(
524 |         [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
525 | 
526 |     schedule = np.concatenate((warmup_schedule, schedule))
527 | 
528 |     assert len(schedule) == epochs * niter_per_ep
529 |     return schedule
530 | 
531 | def adjust_learning_rate(optimizer, epoch, args):
532 |     """Decay the learning rate with half-cycle cosine after warmup"""
533 |     if epoch < args.warmup_epochs:
534 |         lr = args.lr * epoch / args.warmup_epochs
535 |     else:
536 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
537 |             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
538 |     for param_group in optimizer.param_groups:
539 |         if "lr_scale" in param_group:
540 |             param_group["lr"] = lr * param_group["lr_scale"]
541 |         else:
542 |             param_group["lr"] = lr
543 |     return lr
544 | 
545 | def remap_checkpoint_keys(ckpt):
546 |     new_ckpt = OrderedDict()
547 |     for k, v in ckpt.items():
548 |         if k.startswith('encoder'):
549 |             k = '.'.join(k.split('.')[1:]) # remove encoder in the name
550 |         if k.endswith('kernel'):
551 |             k = '.'.join(k.split('.')[:-1]) # remove kernel in the name
552 |             new_k = k + '.weight'
553 |             if len(v.shape) == 3: # reshape standard convolution
554 |                 kv, in_dim, out_dim = v.shape
555 |                 ks = int(math.sqrt(kv))
556 |                 new_ckpt[new_k] = v.permute(2, 1, 0).\
557 |                     reshape(out_dim, in_dim, ks, ks).transpose(3, 2)
558 |             elif len(v.shape) == 2: # reshape depthwise convolution
559 |                 kv, dim = v.shape
560 |                 ks = int(math.sqrt(kv))
561 |                 new_ckpt[new_k] = v.permute(1, 0).\
562 |                     reshape(dim, 1, ks, ks).transpose(3, 2)
563 |             continue
564 |         elif 'ln' in k or 'linear' in k:
565 |             k = k.split('.')
566 |             k.pop(-2) # remove ln and linear in the name
567 |             new_k = '.'.join(k)
568 |         else:
569 |             new_k = k
570 |         new_ckpt[new_k] = v
571 | 
572 |     # reshape grn affine parameters and biases
573 |     for k, v in new_ckpt.items():
574 |         if k.endswith('bias') and len(v.shape) != 1:
575 |             new_ckpt[k] = v.reshape(-1)
576 |         elif 'grn' in k:
577 |             new_ckpt[k] = v.unsqueeze(0).unsqueeze(1)
578 |     return new_ckpt
--------------------------------------------------------------------------------
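To make the relationship between these helpers concrete, below is a minimal sketch of how they can be wired together when fine-tuning from an FCMAE pre-trained checkpoint. It is not copied from main_finetune.py: the checkpoint path and hyper-parameter values are placeholders, the hand-built `args` namespace carries only the fields these functions read, and both the `convnextv2_base` factory (from models/convnextv2.py, not shown in this excerpt) and the length of the value list handed to LayerDecayValueAssigner are assumptions based on the rest of the repository.

# A minimal sketch, assuming the pieces named above; not the authoritative recipe.
import argparse
import torch

import utils
from models import convnextv2
from optim_factory import LayerDecayValueAssigner, create_optimizer

args = argparse.Namespace(
    opt='adamw', lr=8e-4, min_lr=1e-6, weight_decay=0.05, opt_eps=1e-8, opt_betas=None,
    momentum=0.9, epochs=100, warmup_epochs=20)

model = convnextv2.convnextv2_base(num_classes=1000)  # assumed factory from models/convnextv2.py

# Remap an FCMAE (sparse) checkpoint to the dense ConvNeXt V2 layout, then load it
# non-strictly so the freshly initialized classifier head is simply reported as missing.
checkpoint = torch.load('pretrained-fcmae.pth', map_location='cpu')  # hypothetical path
state_dict = utils.remap_checkpoint_keys(checkpoint['model'])
utils.load_state_dict(model, state_dict)

# Layer-wise lr decay ('single' mode: every block gets its own layer id); the head keeps
# scale 1.0. The list length (num_layers + 2) is an assumption about the id range.
depths = [3, 3, 27, 3]
num_layers = sum(depths)
assigner = LayerDecayValueAssigner(
    [0.9 ** (num_layers + 1 - i) for i in range(num_layers + 2)],
    depths=depths, layer_decay_type='single')

optimizer = create_optimizer(
    args, model, get_num_layer=assigner.get_layer_id, get_layer_scale=assigner.get_scale)

for epoch in range(args.epochs):
    # Half-cycle cosine schedule after warmup; each param group keeps its own lr_scale.
    lr = utils.adjust_learning_rate(optimizer, epoch, args)
    # ... run one training epoch here ...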