├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── FINETUNE.md ├── LICENSE ├── PRETRAIN.md ├── README.md ├── demo └── mae_visualize.ipynb ├── engine_finetune.py ├── engine_pretrain.py ├── main_finetune.py ├── main_linprobe.py ├── main_pretrain.py ├── models_mae.py ├── models_vit.py ├── submitit_finetune.py ├── submitit_linprobe.py ├── submitit_pretrain.py └── util ├── crop.py ├── datasets.py ├── lars.py ├── lr_decay.py ├── lr_sched.py ├── misc.py └── pos_embed.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . 
All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mae 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to mae, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /FINETUNE.md: -------------------------------------------------------------------------------- 1 | ## Fine-tuning Pre-trained MAE for Classification 2 | 3 | ### Evaluation 4 | 5 | As a sanity check, run evaluation using our ImageNet **fine-tuned** models: 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |
| | ViT-Base | ViT-Large | ViT-Huge |
|---|---|---|---|
| fine-tuned checkpoint | download | download | download |
| md5 | `1b25e9` | `51f550` | `2541f2` |
| reference ImageNet accuracy | 83.664 | 85.952 | 86.928 |
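The md5 entries above are short prefixes intended for a quick integrity check. A minimal sketch (plain Python; the filename is taken from the evaluation command below) for comparing a downloaded checkpoint against them:

```python
# Sketch: compare a downloaded checkpoint against the md5 prefixes listed above.
import hashlib

def md5sum(path, chunk_size=1 << 20):
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

print(md5sum("mae_finetuned_vit_base.pth"))  # expect it to start with 1b25e9
```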
31 | 32 | Evaluate ViT-Base on a single GPU (`${IMAGENET_DIR}` is a directory containing `{train, val}` sets of ImageNet): 33 | ``` 34 | python main_finetune.py --eval --resume mae_finetuned_vit_base.pth --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR} 35 | ``` 36 | This should give: 37 | ``` 38 | * Acc@1 83.664 Acc@5 96.530 loss 0.731 39 | ``` 40 | 41 | Evaluate ViT-Large: 42 | ``` 43 | python main_finetune.py --eval --resume mae_finetuned_vit_large.pth --model vit_large_patch16 --batch_size 16 --data_path ${IMAGENET_DIR} 44 | ``` 45 | This should give: 46 | ``` 47 | * Acc@1 85.952 Acc@5 97.570 loss 0.646 48 | ``` 49 | 50 | Evaluate ViT-Huge: 51 | ``` 52 | python main_finetune.py --eval --resume mae_finetuned_vit_huge.pth --model vit_huge_patch14 --batch_size 16 --data_path ${IMAGENET_DIR} 53 | ``` 54 | This should give: 55 | ``` 56 | * Acc@1 86.928 Acc@5 98.088 loss 0.584 57 | ``` 58 | 59 | ### Fine-tuning 60 | 61 | Get our pre-trained checkpoints from [here](https://github.com/facebookresearch/mae/#pre-trained-checkpoints). 62 | 63 | To fine-tune with **multi-node distributed training**, run the following on 4 nodes with 8 GPUs each: 64 | ``` 65 | python submitit_finetune.py \ 66 | --job_dir ${JOB_DIR} \ 67 | --nodes 4 \ 68 | --batch_size 32 \ 69 | --model vit_base_patch16 \ 70 | --finetune ${PRETRAIN_CHKPT} \ 71 | --epochs 100 \ 72 | --blr 5e-4 --layer_decay 0.65 \ 73 | --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 74 | --dist_eval --data_path ${IMAGENET_DIR} 75 | ``` 76 | - Install submitit (`pip install submitit`) first. 77 | - Here the effective batch size is 32 (`batch_size` per gpu) * 4 (`nodes`) * 8 (gpus per node) = 1024. 78 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 79 | - We have run 4 trials with different random seeds. The results are 83.63, 83.66, 83.52, 83.46 (mean 83.57 and std 0.08). 80 | - Training time is ~7h11m on 32 V100 GPUs. 81 | 82 | Script for ViT-Large: 83 | ``` 84 | python submitit_finetune.py \ 85 | --job_dir ${JOB_DIR} \ 86 | --nodes 4 --use_volta32 \ 87 | --batch_size 32 \ 88 | --model vit_large_patch16 \ 89 | --finetune ${PRETRAIN_CHKPT} \ 90 | --epochs 50 \ 91 | --blr 1e-3 --layer_decay 0.75 \ 92 | --weight_decay 0.05 --drop_path 0.2 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 93 | --dist_eval --data_path ${IMAGENET_DIR} 94 | ``` 95 | - We have run 4 trials with different random seeds. The results are 85.95, 85.87, 85.76, 85.88 (mean 85.87 and std 0.07). 96 | - Training time is ~8h52m on 32 V100 GPUs. 97 | 98 | Script for ViT-Huge: 99 | ``` 100 | python submitit_finetune.py \ 101 | --job_dir ${JOB_DIR} \ 102 | --nodes 8 --use_volta32 \ 103 | --batch_size 16 \ 104 | --model vit_huge_patch14 \ 105 | --finetune ${PRETRAIN_CHKPT} \ 106 | --epochs 50 \ 107 | --blr 1e-3 --layer_decay 0.75 \ 108 | --weight_decay 0.05 --drop_path 0.3 --reprob 0.25 --mixup 0.8 --cutmix 1.0 \ 109 | --dist_eval --data_path ${IMAGENET_DIR} 110 | ``` 111 | - Training time is ~13h9m on 64 V100 GPUs.
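The batch-size and learning-rate bookkeeping in the bullet points above can be summarized in a few lines. A minimal sketch (plain Python, purely illustrative) of the linear scaling rule used by these recipes:

```python
# Sketch of the effective batch size and linear lr scaling rule described above.
def effective_batch_size(batch_size_per_gpu: int, accum_iter: int, num_gpus: int) -> int:
    # batch_size (per gpu) * gradient-accumulation steps * total number of GPUs
    return batch_size_per_gpu * accum_iter * num_gpus

def absolute_lr(blr: float, eff_batch_size: int) -> float:
    # lr = blr * effective batch size / 256
    return blr * eff_batch_size / 256

# ViT-Base recipe above: 32 per GPU, no gradient accumulation, 4 nodes x 8 GPUs
eff = effective_batch_size(32, 1, 4 * 8)
print(eff, absolute_lr(5e-4, eff))  # 1024 0.002
```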
112 | 113 | To fine-tune our pre-trained ViT-Base with **single-node training**, run the following on 1 node with 8 GPUs: 114 | ``` 115 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \ 116 | --accum_iter 4 \ 117 | --batch_size 32 \ 118 | --model vit_base_patch16 \ 119 | --finetune ${PRETRAIN_CHKPT} \ 120 | --epochs 100 \ 121 | --blr 5e-4 --layer_decay 0.65 \ 122 | --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \ 123 | --dist_eval --data_path ${IMAGENET_DIR} 124 | ``` 125 | - Here the effective batch size is 32 (`batch_size` per gpu) * 4 (`accum_iter`) * 8 (gpus) = 1024. `--accum_iter 4` simulates 4 nodes. 126 | 127 | #### Notes 128 | 129 | - The [pre-trained models we provide](https://github.com/facebookresearch/mae/#pre-trained-checkpoints) are trained with *normalized* pixels `--norm_pix_loss` (1600 epochs, Table 3 in paper). The fine-tuning hyper-parameters are slightly different from the default baseline using *unnormalized* pixels. 130 | 131 | - The original MAE implementation was in TensorFlow+TPU with no explicit mixed precision. This re-implementation is in PyTorch+GPU with automatic mixed precision (`torch.cuda.amp`). We have observed different numerical behavior between the two platforms. In this repo, we use `--global_pool` for fine-tuning; using `--cls_token` performs similarly, but there is a chance of producing NaN when fine-tuning ViT-Huge on GPUs. We did not observe this issue on TPUs. Turning off amp could solve this issue, but it is slower. 132 | 133 | - Here we use RandErase following DeiT: `--reprob 0.25`. Its effect is smaller than the random variance across runs. 134 | 135 | ### Linear Probing 136 | 137 | Run the following on 4 nodes with 8 GPUs each: 138 | ``` 139 | python submitit_linprobe.py \ 140 | --job_dir ${JOB_DIR} \ 141 | --nodes 4 \ 142 | --batch_size 512 \ 143 | --model vit_base_patch16 --cls_token \ 144 | --finetune ${PRETRAIN_CHKPT} \ 145 | --epochs 90 \ 146 | --blr 0.1 \ 147 | --weight_decay 0.0 \ 148 | --dist_eval --data_path ${IMAGENET_DIR} 149 | ``` 150 | - Here the effective batch size is 512 (`batch_size` per gpu) * 4 (`nodes`) * 8 (gpus per node) = 16384. 151 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 152 | - Training time is ~2h20m for 90 epochs on 32 V100 GPUs. 153 | - To run single-node training, follow the instructions in the fine-tuning section above. 154 | 155 | To train ViT-Large or ViT-Huge, set `--model vit_large_patch16` or `--model vit_huge_patch14`. It is sufficient to train for 50 epochs (`--epochs 50`). 156 | 157 | This PT/GPU code produces *better* results for ViT-L/H (see the table below). This is likely caused by the system difference between TF and PT. 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 |
| | ViT-Base | ViT-Large | ViT-Huge |
|---|---|---|---|
| paper (TF/TPU) | 68.0 | 75.8 | 76.6 |
| this repo (PT/GPU) | 67.8 | 76.0 | 77.2 |
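For context on the *normalized* pixel target mentioned in the notes above (`--norm_pix_loss`): each patch is normalized by its own mean and variance before the reconstruction loss is computed. A minimal sketch of that idea (not the repo's exact code; see `models_mae.py` for the actual implementation):

```python
import torch

def norm_pix_target(patches: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Per-patch normalized pixel target, the idea behind --norm_pix_loss.

    patches: (N, L, D) tensor of flattened image patches.
    """
    mean = patches.mean(dim=-1, keepdim=True)
    var = patches.var(dim=-1, keepdim=True)
    return (patches - mean) / (var + eps) ** 0.5
```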
178 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. 
More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. 
Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. 
For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 
-------------------------------------------------------------------------------- /PRETRAIN.md: -------------------------------------------------------------------------------- 1 | ## Pre-training MAE 2 | 3 | To pre-train ViT-Large (recommended default) with **multi-node distributed training**, run the following on 8 nodes with 8 GPUs each: 4 | ``` 5 | python submitit_pretrain.py \ 6 | --job_dir ${JOB_DIR} \ 7 | --nodes 8 \ 8 | --use_volta32 \ 9 | --batch_size 64 \ 10 | --model mae_vit_large_patch16 \ 11 | --norm_pix_loss \ 12 | --mask_ratio 0.75 \ 13 | --epochs 800 \ 14 | --warmup_epochs 40 \ 15 | --blr 1.5e-4 --weight_decay 0.05 \ 16 | --data_path ${IMAGENET_DIR} 17 | ``` 18 | - Here the effective batch size is 64 (`batch_size` per gpu) * 8 (`nodes`) * 8 (gpus per node) = 4096. If memory or the number of GPUs is limited, use `--accum_iter` to maintain the effective batch size, which is `batch_size` (per gpu) * `nodes` * 8 (gpus per node) * `accum_iter`. 19 | - `blr` is the base learning rate. The actual `lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `lr` = `blr` * effective batch size / 256. 20 | - Here we use `--norm_pix_loss` as the target for better representation learning. To train a baseline model (e.g., for visualization), use pixel-based reconstruction and turn off `--norm_pix_loss`. 21 | - We use exactly the same hyper-parameters and configs (initialization, augmentation, etc.) as our TF/TPU implementation. In our sanity checks, this PT/GPU re-implementation can reproduce the TF/TPU results within reasonable random variation. We get 85.5% [fine-tuning](FINETUNE.md) accuracy by pre-training ViT-Large for 800 epochs (85.4% in paper Table 1d with TF/TPU). 22 | - Training time is ~42h on 64 V100 GPUs (800 epochs). 23 | 24 | To train ViT-Base or ViT-Huge, set `--model mae_vit_base_patch16` or `--model mae_vit_huge_patch14`. 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Masked Autoencoders: A PyTorch Implementation 2 | 3 |

4 | 5 |

6 | 7 | 8 | This is a PyTorch/GPU re-implementation of the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377): 9 | ``` 10 | @Article{MaskedAutoencoders2021, 11 | author = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Doll{\'a}r and Ross Girshick}, 12 | journal = {arXiv:2111.06377}, 13 | title = {Masked Autoencoders Are Scalable Vision Learners}, 14 | year = {2021}, 15 | } 16 | ``` 17 | 18 | * The original implementation was in TensorFlow+TPU. This re-implementation is in PyTorch+GPU. 19 | 20 | * This repo is a modification on the [DeiT repo](https://github.com/facebookresearch/deit). Installation and preparation follow that repo. 21 | 22 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. 23 | 24 | ### Catalog 25 | 26 | - [x] Visualization demo 27 | - [x] Pre-trained checkpoints + fine-tuning code 28 | - [x] Pre-training code 29 | 30 | ### Visualization demo 31 | 32 | Run our interactive visualization demo using [Colab notebook](https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb) (no GPU needed): 33 |

34 | 35 |

36 | 37 | ### Fine-tuning with pre-trained checkpoints 38 | 39 | The following table provides the pre-trained checkpoints used in the paper, converted from TF/TPU to PT/GPU: 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 |
| | ViT-Base | ViT-Large | ViT-Huge |
|---|---|---|---|
| pre-trained checkpoint | download | download | download |
| md5 | `8cad7c` | `b8b06e` | `9bdbb0` |
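The snippet below is a condensed sketch of how `main_finetune.py` (included further down in this repo) consumes one of these pre-trained checkpoints; the checkpoint path is a placeholder for a downloaded file (`${PRETRAIN_CHKPT}` in FINETUNE.md).

```python
# Condensed sketch of the checkpoint-loading logic in main_finetune.py.
import torch
from timm.models.layers import trunc_normal_

import models_vit
from util.pos_embed import interpolate_pos_embed

model = models_vit.vit_base_patch16(num_classes=1000, drop_path_rate=0.1, global_pool=True)

# placeholder path for a downloaded pre-trained checkpoint
checkpoint = torch.load("path/to/pretrained_checkpoint.pth", map_location="cpu")
checkpoint_model = checkpoint["model"]

# drop the pre-training head if its shape does not match the new classifier
state_dict = model.state_dict()
for k in ["head.weight", "head.bias"]:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        del checkpoint_model[k]

# resize position embeddings if the image size / patch grid changed
interpolate_pos_embed(model, checkpoint_model)

msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg.missing_keys)  # the classifier (and fc_norm with --global_pool) is newly initialized

# re-initialize the classification head as in the repo
trunc_normal_(model.head.weight, std=2e-5)
```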
59 | 60 | The fine-tuning instruction is in [FINETUNE.md](FINETUNE.md). 61 | 62 | By fine-tuning these pre-trained models, we rank #1 in these classification tasks (detailed in the paper): 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 |
| | ViT-B | ViT-L | ViT-H | ViT-H<sub>448</sub> | prev best |
|---|---|---|---|---|---|
| ImageNet-1K (no external data) | 83.6 | 85.9 | 86.9 | 87.8 | 87.1 |
| **Evaluations of the same model weights (fine-tuned on the original ImageNet-1K):** | | | | | |
| ImageNet-Corruption (error rate, lower is better) | 51.7 | 41.8 | 33.8 | 36.8 | 42.5 |
| ImageNet-Adversarial | 35.9 | 57.1 | 68.2 | 76.7 | 35.8 |
| ImageNet-Rendition | 48.3 | 59.9 | 64.4 | 66.5 | 48.7 |
| ImageNet-Sketch | 34.5 | 45.3 | 49.6 | 50.9 | 36.0 |
| **Transfer learning by fine-tuning the pre-trained MAE on the target dataset:** | | | | | |
| iNaturalist 2017 | 70.5 | 75.7 | 79.3 | 83.4 | 75.4 |
| iNaturalist 2018 | 75.4 | 80.1 | 83.0 | 86.8 | 81.2 |
| iNaturalist 2019 | 80.5 | 83.4 | 85.7 | 88.3 | 84.1 |
| Places205 | 63.9 | 65.8 | 65.9 | 66.8 | 66.0 |
| Places365 | 57.9 | 59.4 | 59.8 | 60.3 | 58.0 |
149 | 150 | ### Pre-training 151 | 152 | The pre-training instruction is in [PRETRAIN.md](PRETRAIN.md). 153 | 154 | ### License 155 | 156 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 157 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | 18 | from timm.data import Mixup 19 | from timm.utils import accuracy 20 | 21 | import util.misc as misc 22 | import util.lr_sched as lr_sched 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 28 | mixup_fn: Optional[Mixup] = None, log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | samples = samples.to(device, non_blocking=True) 50 | targets = targets.to(device, non_blocking=True) 51 | 52 | if mixup_fn is not None: 53 | samples, targets = mixup_fn(samples, targets) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(samples) 57 | loss = criterion(outputs, targets) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | loss /= accum_iter 66 | loss_scaler(loss, optimizer, clip_grad=max_norm, 67 | parameters=model.parameters(), create_graph=False, 68 | update_grad=(data_iter_step + 1) % accum_iter == 0) 69 | if (data_iter_step + 1) % accum_iter == 0: 70 | optimizer.zero_grad() 71 | 72 | torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | min_lr = 10. 76 | max_lr = 0. 77 | for group in optimizer.param_groups: 78 | min_lr = min(min_lr, group["lr"]) 79 | max_lr = max(max_lr, group["lr"]) 80 | 81 | metric_logger.update(lr=max_lr) 82 | 83 | loss_value_reduce = misc.all_reduce_mean(loss_value) 84 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 85 | """ We use epoch_1000x as the x-axis in tensorboard. 86 | This calibrates different curves when batch size changes. 
87 | """ 88 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 89 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 90 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 91 | 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | print("Averaged stats:", metric_logger) 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | 97 | 98 | @torch.no_grad() 99 | def evaluate(data_loader, model, device): 100 | criterion = torch.nn.CrossEntropyLoss() 101 | 102 | metric_logger = misc.MetricLogger(delimiter=" ") 103 | header = 'Test:' 104 | 105 | # switch to evaluation mode 106 | model.eval() 107 | 108 | for batch in metric_logger.log_every(data_loader, 10, header): 109 | images = batch[0] 110 | target = batch[-1] 111 | images = images.to(device, non_blocking=True) 112 | target = target.to(device, non_blocking=True) 113 | 114 | # compute output 115 | with torch.cuda.amp.autocast(): 116 | output = model(images) 117 | loss = criterion(output, target) 118 | 119 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 120 | 121 | batch_size = images.shape[0] 122 | metric_logger.update(loss=loss.item()) 123 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 124 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 125 | # gather the stats from all processes 126 | metric_logger.synchronize_between_processes() 127 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 128 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 129 | 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, 24 | log_writer=None, 25 | args=None): 26 | model.train(True) 27 | metric_logger = misc.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 20 31 | 32 | accum_iter = args.accum_iter 33 | 34 | optimizer.zero_grad() 35 | 36 | if log_writer is not None: 37 | print('log_dir: {}'.format(log_writer.log_dir)) 38 | 39 | for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 40 | 41 | # we use a per iteration (instead of per epoch) lr scheduler 42 | if data_iter_step % accum_iter == 0: 43 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 44 | 45 | samples = samples.to(device, non_blocking=True) 46 | 47 | with torch.cuda.amp.autocast(): 48 | loss, _, _ = model(samples, mask_ratio=args.mask_ratio) 49 | 50 | loss_value = loss.item() 51 | 52 | if not math.isfinite(loss_value): 53 | print("Loss is {}, stopping training".format(loss_value)) 54 | sys.exit(1) 55 | 56 | loss /= accum_iter 57 | loss_scaler(loss, optimizer, parameters=model.parameters(), 58 | update_grad=(data_iter_step + 1) % accum_iter == 0) 59 | if (data_iter_step + 1) % accum_iter == 0: 60 | optimizer.zero_grad() 61 | 62 | torch.cuda.synchronize() 63 | 64 | metric_logger.update(loss=loss_value) 65 | 66 | lr = optimizer.param_groups[0]["lr"] 67 | metric_logger.update(lr=lr) 68 | 69 | loss_value_reduce = misc.all_reduce_mean(loss_value) 70 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 71 | """ We use epoch_1000x as the x-axis in tensorboard. 72 | This calibrates different curves when batch size changes. 73 | """ 74 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 75 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 76 | log_writer.add_scalar('lr', lr, epoch_1000x) 77 | 78 | 79 | # gather the stats from all processes 80 | metric_logger.synchronize_between_processes() 81 | print("Averaged stats:", metric_logger) 82 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | 24 | import timm 25 | 26 | assert timm.__version__ == "0.3.2" # version check 27 | from timm.models.layers import trunc_normal_ 28 | from timm.data.mixup import Mixup 29 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 30 | 31 | import util.lr_decay as lrd 32 | import util.misc as misc 33 | from util.datasets import build_dataset 34 | from util.pos_embed import interpolate_pos_embed 35 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 36 | 37 | import models_vit 38 | 39 | from engine_finetune import train_one_epoch, evaluate 40 | 41 | 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) 44 | parser.add_argument('--batch_size', default=64, type=int, 45 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 46 | parser.add_argument('--epochs', default=50, type=int) 47 | parser.add_argument('--accum_iter', default=1, type=int, 48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | 54 | parser.add_argument('--input_size', default=224, type=int, 55 | help='images input size') 56 | 57 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 58 | help='Drop path rate (default: 0.1)') 59 | 60 | # Optimizer parameters 61 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 62 | help='Clip gradient norm (default: None, no clipping)') 63 | parser.add_argument('--weight_decay', type=float, default=0.05, 64 | help='weight decay (default: 0.05)') 65 | 66 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 67 | help='learning rate (absolute lr)') 68 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 69 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 70 | parser.add_argument('--layer_decay', type=float, default=0.75, 71 | help='layer-wise lr decay from ELECTRA/BEiT') 72 | 73 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 74 | help='lower lr bound for cyclic schedulers that hit 0') 75 | 76 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', 77 | help='epochs to warmup LR') 78 | 79 | # Augmentation parameters 80 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', 81 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 82 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 83 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 84 | parser.add_argument('--smoothing', type=float, default=0.1, 85 | help='Label smoothing (default: 0.1)') 86 | 87 | # * Random Erase params 88 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 89 | help='Random erase prob (default: 0.25)') 90 | parser.add_argument('--remode', type=str, default='pixel', 91 | help='Random erase mode (default: "pixel")') 92 | parser.add_argument('--recount', type=int, default=1, 93 | help='Random erase count (default: 1)') 94 | parser.add_argument('--resplit', action='store_true', default=False, 95 | help='Do not random erase first (clean) augmentation split') 96 | 97 | # * Mixup params 98 | parser.add_argument('--mixup', type=float, default=0, 99 | help='mixup alpha, mixup enabled if > 0.') 100 | parser.add_argument('--cutmix', type=float, default=0, 101 | help='cutmix alpha, cutmix enabled if > 0.') 102 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 103 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 104 | parser.add_argument('--mixup_prob', type=float, default=1.0, 105 | help='Probability of performing mixup or cutmix when either/both is enabled') 106 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 107 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 108 | parser.add_argument('--mixup_mode', type=str, default='batch', 109 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 110 | 111 | # * Finetuning params 112 | parser.add_argument('--finetune', default='', 113 | help='finetune from checkpoint') 114 | parser.add_argument('--global_pool', action='store_true') 115 | parser.set_defaults(global_pool=True) 116 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 117 | help='Use class token instead of global pool for classification') 118 | 119 | # Dataset parameters 120 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 121 | help='dataset path') 122 | parser.add_argument('--nb_classes', default=1000, type=int, 123 | help='number of the classification types') 124 | 125 | parser.add_argument('--output_dir', default='./output_dir', 126 | help='path where to save, empty for no saving') 127 | parser.add_argument('--log_dir', default='./output_dir', 128 | help='path where to tensorboard log') 129 | parser.add_argument('--device', default='cuda', 130 | help='device to use for training / testing') 131 | parser.add_argument('--seed', default=0, type=int) 132 | parser.add_argument('--resume', default='', 133 | help='resume from checkpoint') 134 | 135 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 136 | help='start epoch') 137 | parser.add_argument('--eval', action='store_true', 138 | help='Perform evaluation only') 139 | parser.add_argument('--dist_eval', action='store_true', default=False, 140 | help='Enabling distributed evaluation (recommended during training for faster monitor') 141 | parser.add_argument('--num_workers', default=10, type=int) 142 | parser.add_argument('--pin_mem', action='store_true', 143 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 144 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 145 | parser.set_defaults(pin_mem=True) 146 | 147 | # distributed training parameters 148 | parser.add_argument('--world_size', default=1, type=int, 149 | help='number of distributed 
processes') 150 | parser.add_argument('--local_rank', default=-1, type=int) 151 | parser.add_argument('--dist_on_itp', action='store_true') 152 | parser.add_argument('--dist_url', default='env://', 153 | help='url used to set up distributed training') 154 | 155 | return parser 156 | 157 | 158 | def main(args): 159 | misc.init_distributed_mode(args) 160 | 161 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 162 | print("{}".format(args).replace(', ', ',\n')) 163 | 164 | device = torch.device(args.device) 165 | 166 | # fix the seed for reproducibility 167 | seed = args.seed + misc.get_rank() 168 | torch.manual_seed(seed) 169 | np.random.seed(seed) 170 | 171 | cudnn.benchmark = True 172 | 173 | dataset_train = build_dataset(is_train=True, args=args) 174 | dataset_val = build_dataset(is_train=False, args=args) 175 | 176 | if True: # args.distributed: 177 | num_tasks = misc.get_world_size() 178 | global_rank = misc.get_rank() 179 | sampler_train = torch.utils.data.DistributedSampler( 180 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 181 | ) 182 | print("Sampler_train = %s" % str(sampler_train)) 183 | if args.dist_eval: 184 | if len(dataset_val) % num_tasks != 0: 185 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 186 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 187 | 'equal num of samples per-process.') 188 | sampler_val = torch.utils.data.DistributedSampler( 189 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 190 | else: 191 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 192 | else: 193 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 194 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 195 | 196 | if global_rank == 0 and args.log_dir is not None and not args.eval: 197 | os.makedirs(args.log_dir, exist_ok=True) 198 | log_writer = SummaryWriter(log_dir=args.log_dir) 199 | else: 200 | log_writer = None 201 | 202 | data_loader_train = torch.utils.data.DataLoader( 203 | dataset_train, sampler=sampler_train, 204 | batch_size=args.batch_size, 205 | num_workers=args.num_workers, 206 | pin_memory=args.pin_mem, 207 | drop_last=True, 208 | ) 209 | 210 | data_loader_val = torch.utils.data.DataLoader( 211 | dataset_val, sampler=sampler_val, 212 | batch_size=args.batch_size, 213 | num_workers=args.num_workers, 214 | pin_memory=args.pin_mem, 215 | drop_last=False 216 | ) 217 | 218 | mixup_fn = None 219 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 220 | if mixup_active: 221 | print("Mixup is activated!") 222 | mixup_fn = Mixup( 223 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 224 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 225 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 226 | 227 | model = models_vit.__dict__[args.model]( 228 | num_classes=args.nb_classes, 229 | drop_path_rate=args.drop_path, 230 | global_pool=args.global_pool, 231 | ) 232 | 233 | if args.finetune and not args.eval: 234 | checkpoint = torch.load(args.finetune, map_location='cpu') 235 | 236 | print("Load pre-trained checkpoint from: %s" % args.finetune) 237 | checkpoint_model = checkpoint['model'] 238 | state_dict = model.state_dict() 239 | for k in ['head.weight', 'head.bias']: 240 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 241 | print(f"Removing key {k} from pretrained checkpoint") 242 | del checkpoint_model[k] 243 | 244 | # interpolate position embedding 245 | interpolate_pos_embed(model, checkpoint_model) 246 | 247 | # load pre-trained model 248 | msg = model.load_state_dict(checkpoint_model, strict=False) 249 | print(msg) 250 | 251 | if args.global_pool: 252 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 253 | else: 254 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 255 | 256 | # manually initialize fc layer 257 | trunc_normal_(model.head.weight, std=2e-5) 258 | 259 | model.to(device) 260 | 261 | model_without_ddp = model 262 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 263 | 264 | print("Model = %s" % str(model_without_ddp)) 265 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 266 | 267 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 268 | 269 | if args.lr is None: # only base_lr is specified 270 | args.lr = args.blr * eff_batch_size / 256 271 | 272 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 273 | print("actual lr: %.2e" % args.lr) 274 | 275 | print("accumulate grad iterations: %d" % args.accum_iter) 276 | print("effective batch size: %d" % eff_batch_size) 277 | 278 | if args.distributed: 279 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 280 | model_without_ddp = model.module 281 | 282 | # build optimizer with layer-wise lr decay (lrd) 283 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 284 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 285 | layer_decay=args.layer_decay 286 | ) 287 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 288 | loss_scaler = NativeScaler() 289 | 290 | if mixup_fn is not None: 291 | # smoothing is handled with mixup label transform 292 | criterion = SoftTargetCrossEntropy() 293 | elif args.smoothing > 0.: 294 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 295 | else: 296 | criterion = torch.nn.CrossEntropyLoss() 297 | 298 | print("criterion = %s" % str(criterion)) 299 | 300 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 301 | 302 | if args.eval: 303 | test_stats = evaluate(data_loader_val, model, device) 304 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 305 | exit(0) 306 | 307 | print(f"Start training for {args.epochs} epochs") 308 | start_time = time.time() 309 | max_accuracy = 0.0 310 | for 
epoch in range(args.start_epoch, args.epochs): 311 | if args.distributed: 312 | data_loader_train.sampler.set_epoch(epoch) 313 | train_stats = train_one_epoch( 314 | model, criterion, data_loader_train, 315 | optimizer, device, epoch, loss_scaler, 316 | args.clip_grad, mixup_fn, 317 | log_writer=log_writer, 318 | args=args 319 | ) 320 | if args.output_dir: 321 | misc.save_model( 322 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 323 | loss_scaler=loss_scaler, epoch=epoch) 324 | 325 | test_stats = evaluate(data_loader_val, model, device) 326 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 327 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 328 | print(f'Max accuracy: {max_accuracy:.2f}%') 329 | 330 | if log_writer is not None: 331 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 332 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 333 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 334 | 335 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 336 | **{f'test_{k}': v for k, v in test_stats.items()}, 337 | 'epoch': epoch, 338 | 'n_parameters': n_parameters} 339 | 340 | if args.output_dir and misc.is_main_process(): 341 | if log_writer is not None: 342 | log_writer.flush() 343 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 344 | f.write(json.dumps(log_stats) + "\n") 345 | 346 | total_time = time.time() - start_time 347 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 348 | print('Training time {}'.format(total_time_str)) 349 | 350 | 351 | if __name__ == '__main__': 352 | args = get_args_parser() 353 | args = args.parse_args() 354 | if args.output_dir: 355 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 356 | main(args) 357 | -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
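# Note on the fine-tuning criterion configured in main_finetune.py above (illustrative
# sketch with hypothetical tensors; the actual call happens inside
# engine_finetune.train_one_epoch, which receives mixup_fn): when mixup/cutmix is
# active, the batch images are mixed and the integer labels become soft label vectors,
# which is why SoftTargetCrossEntropy is selected instead of plain cross-entropy.
#
#   samples, targets = mixup_fn(samples, targets)     # targets: [B] -> [B, nb_classes]
#   loss = SoftTargetCrossEntropy()(model(samples), targets)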
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # MoCo v3: https://github.com/facebookresearch/moco-v3 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import numpy as np 16 | import os 17 | import time 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from torch.utils.tensorboard import SummaryWriter 23 | import torchvision.transforms as transforms 24 | import torchvision.datasets as datasets 25 | 26 | import timm 27 | 28 | assert timm.__version__ == "0.3.2" # version check 29 | from timm.models.layers import trunc_normal_ 30 | 31 | import util.misc as misc 32 | from util.pos_embed import interpolate_pos_embed 33 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 34 | from util.lars import LARS 35 | from util.crop import RandomResizedCrop 36 | 37 | import models_vit 38 | 39 | from engine_finetune import train_one_epoch, evaluate 40 | 41 | 42 | def get_args_parser(): 43 | parser = argparse.ArgumentParser('MAE linear probing for image classification', add_help=False) 44 | parser.add_argument('--batch_size', default=512, type=int, 45 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 46 | parser.add_argument('--epochs', default=90, type=int) 47 | parser.add_argument('--accum_iter', default=1, type=int, 48 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | 54 | # Optimizer parameters 55 | parser.add_argument('--weight_decay', type=float, default=0, 56 | help='weight decay (default: 0 for linear probe following MoCo v1)') 57 | 58 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 59 | help='learning rate (absolute lr)') 60 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR', 61 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 62 | 63 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 64 | help='lower lr bound for cyclic schedulers that hit 0') 65 | 66 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', 67 | help='epochs to warmup LR') 68 | 69 | # * Finetuning params 70 | parser.add_argument('--finetune', default='', 71 | help='finetune from checkpoint') 72 | parser.add_argument('--global_pool', action='store_true') 73 | parser.set_defaults(global_pool=False) 74 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 75 | help='Use class token instead of global pool for classification') 76 | 77 | # Dataset parameters 78 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 79 | help='dataset path') 80 | parser.add_argument('--nb_classes', default=1000, type=int, 81 | help='number of the classification types') 82 | 83 | parser.add_argument('--output_dir', default='./output_dir', 84 | help='path where to save, empty for no saving') 85 | parser.add_argument('--log_dir', default='./output_dir', 86 | help='path where to tensorboard log') 87 | parser.add_argument('--device', default='cuda', 88 | help='device to use for training / testing') 89 | parser.add_argument('--seed', default=0, type=int) 90 | parser.add_argument('--resume', 
default='', 91 | help='resume from checkpoint') 92 | 93 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 94 | help='start epoch') 95 | parser.add_argument('--eval', action='store_true', 96 | help='Perform evaluation only') 97 | parser.add_argument('--dist_eval', action='store_true', default=False, 98 | help='Enabling distributed evaluation (recommended during training for faster monitor') 99 | parser.add_argument('--num_workers', default=10, type=int) 100 | parser.add_argument('--pin_mem', action='store_true', 101 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 102 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 103 | parser.set_defaults(pin_mem=True) 104 | 105 | # distributed training parameters 106 | parser.add_argument('--world_size', default=1, type=int, 107 | help='number of distributed processes') 108 | parser.add_argument('--local_rank', default=-1, type=int) 109 | parser.add_argument('--dist_on_itp', action='store_true') 110 | parser.add_argument('--dist_url', default='env://', 111 | help='url used to set up distributed training') 112 | 113 | return parser 114 | 115 | 116 | def main(args): 117 | misc.init_distributed_mode(args) 118 | 119 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 120 | print("{}".format(args).replace(', ', ',\n')) 121 | 122 | device = torch.device(args.device) 123 | 124 | # fix the seed for reproducibility 125 | seed = args.seed + misc.get_rank() 126 | torch.manual_seed(seed) 127 | np.random.seed(seed) 128 | 129 | cudnn.benchmark = True 130 | 131 | # linear probe: weak augmentation 132 | transform_train = transforms.Compose([ 133 | RandomResizedCrop(224, interpolation=3), 134 | transforms.RandomHorizontalFlip(), 135 | transforms.ToTensor(), 136 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 137 | transform_val = transforms.Compose([ 138 | transforms.Resize(256, interpolation=3), 139 | transforms.CenterCrop(224), 140 | transforms.ToTensor(), 141 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 142 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 143 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val) 144 | print(dataset_train) 145 | print(dataset_val) 146 | 147 | if True: # args.distributed: 148 | num_tasks = misc.get_world_size() 149 | global_rank = misc.get_rank() 150 | sampler_train = torch.utils.data.DistributedSampler( 151 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 152 | ) 153 | print("Sampler_train = %s" % str(sampler_train)) 154 | if args.dist_eval: 155 | if len(dataset_val) % num_tasks != 0: 156 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 157 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 158 | 'equal num of samples per-process.') 159 | sampler_val = torch.utils.data.DistributedSampler( 160 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 161 | else: 162 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 163 | else: 164 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 165 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 166 | 167 | if global_rank == 0 and args.log_dir is not None and not args.eval: 168 | os.makedirs(args.log_dir, exist_ok=True) 169 | log_writer = SummaryWriter(log_dir=args.log_dir) 170 | else: 171 | log_writer = None 172 | 173 | data_loader_train = torch.utils.data.DataLoader( 174 | dataset_train, sampler=sampler_train, 175 | batch_size=args.batch_size, 176 | num_workers=args.num_workers, 177 | pin_memory=args.pin_mem, 178 | drop_last=True, 179 | ) 180 | 181 | data_loader_val = torch.utils.data.DataLoader( 182 | dataset_val, sampler=sampler_val, 183 | batch_size=args.batch_size, 184 | num_workers=args.num_workers, 185 | pin_memory=args.pin_mem, 186 | drop_last=False 187 | ) 188 | 189 | model = models_vit.__dict__[args.model]( 190 | num_classes=args.nb_classes, 191 | global_pool=args.global_pool, 192 | ) 193 | 194 | if args.finetune and not args.eval: 195 | checkpoint = torch.load(args.finetune, map_location='cpu') 196 | 197 | print("Load pre-trained checkpoint from: %s" % args.finetune) 198 | checkpoint_model = checkpoint['model'] 199 | state_dict = model.state_dict() 200 | for k in ['head.weight', 'head.bias']: 201 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 202 | print(f"Removing key {k} from pretrained checkpoint") 203 | del checkpoint_model[k] 204 | 205 | # interpolate position embedding 206 | interpolate_pos_embed(model, checkpoint_model) 207 | 208 | # load pre-trained model 209 | msg = model.load_state_dict(checkpoint_model, strict=False) 210 | print(msg) 211 | 212 | if args.global_pool: 213 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 214 | else: 215 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 216 | 217 | # manually initialize fc layer: following MoCo v3 218 | trunc_normal_(model.head.weight, std=0.01) 219 | 220 | # for linear prob only 221 | # hack: revise model's head with BN 222 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 223 | # freeze all but the head 224 | for _, p in model.named_parameters(): 225 | p.requires_grad = False 226 | for _, p in model.head.named_parameters(): 227 | p.requires_grad = True 228 | 229 | model.to(device) 230 | 231 | model_without_ddp = model 232 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 233 | 234 | print("Model = %s" % str(model_without_ddp)) 235 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 236 | 237 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 238 | 239 | if args.lr is None: # only base_lr is specified 240 | args.lr = args.blr * eff_batch_size / 256 241 | 242 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 243 | print("actual lr: %.2e" % args.lr) 244 | 245 | print("accumulate grad iterations: %d" % args.accum_iter) 246 | print("effective batch size: %d" % eff_batch_size) 247 | 248 | if args.distributed: 249 | model = 
torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 250 | model_without_ddp = model.module 251 | 252 | optimizer = LARS(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay) 253 | print(optimizer) 254 | loss_scaler = NativeScaler() 255 | 256 | criterion = torch.nn.CrossEntropyLoss() 257 | 258 | print("criterion = %s" % str(criterion)) 259 | 260 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 261 | 262 | if args.eval: 263 | test_stats = evaluate(data_loader_val, model, device) 264 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 265 | exit(0) 266 | 267 | print(f"Start training for {args.epochs} epochs") 268 | start_time = time.time() 269 | max_accuracy = 0.0 270 | for epoch in range(args.start_epoch, args.epochs): 271 | if args.distributed: 272 | data_loader_train.sampler.set_epoch(epoch) 273 | train_stats = train_one_epoch( 274 | model, criterion, data_loader_train, 275 | optimizer, device, epoch, loss_scaler, 276 | max_norm=None, 277 | log_writer=log_writer, 278 | args=args 279 | ) 280 | if args.output_dir: 281 | misc.save_model( 282 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 283 | loss_scaler=loss_scaler, epoch=epoch) 284 | 285 | test_stats = evaluate(data_loader_val, model, device) 286 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 287 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 288 | print(f'Max accuracy: {max_accuracy:.2f}%') 289 | 290 | if log_writer is not None: 291 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch) 292 | log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch) 293 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) 294 | 295 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 296 | **{f'test_{k}': v for k, v in test_stats.items()}, 297 | 'epoch': epoch, 298 | 'n_parameters': n_parameters} 299 | 300 | if args.output_dir and misc.is_main_process(): 301 | if log_writer is not None: 302 | log_writer.flush() 303 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 304 | f.write(json.dumps(log_stats) + "\n") 305 | 306 | total_time = time.time() - start_time 307 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 308 | print('Training time {}'.format(total_time_str)) 309 | 310 | 311 | if __name__ == '__main__': 312 | args = get_args_parser() 313 | args = args.parse_args() 314 | if args.output_dir: 315 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 316 | main(args) 317 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
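# Worked example of the learning-rate scaling rule shared by these training scripts
# (absolute_lr = blr * eff_batch_size / 256, where eff_batch_size = batch_size *
# accum_iter * world_size). The 8-GPU launch is hypothetical; the other values are
# the linear-probing defaults from main_linprobe.py above:
#
#   eff_batch_size = 512 * 1 * 8      = 4096
#   absolute_lr    = 0.1 * 4096 / 256 = 1.6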
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.backends.cudnn as cudnn 21 | from torch.utils.tensorboard import SummaryWriter 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | 25 | import timm 26 | 27 | assert timm.__version__ == "0.3.2" # version check 28 | import timm.optim.optim_factory as optim_factory 29 | 30 | import util.misc as misc 31 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 32 | 33 | import models_mae 34 | 35 | from engine_pretrain import train_one_epoch 36 | 37 | 38 | def get_args_parser(): 39 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 40 | parser.add_argument('--batch_size', default=64, type=int, 41 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 42 | parser.add_argument('--epochs', default=400, type=int) 43 | parser.add_argument('--accum_iter', default=1, type=int, 44 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 45 | 46 | # Model parameters 47 | parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL', 48 | help='Name of model to train') 49 | 50 | parser.add_argument('--input_size', default=224, type=int, 51 | help='images input size') 52 | 53 | parser.add_argument('--mask_ratio', default=0.75, type=float, 54 | help='Masking ratio (percentage of removed patches).') 55 | 56 | parser.add_argument('--norm_pix_loss', action='store_true', 57 | help='Use (per-patch) normalized pixels as targets for computing loss') 58 | parser.set_defaults(norm_pix_loss=False) 59 | 60 | # Optimizer parameters 61 | parser.add_argument('--weight_decay', type=float, default=0.05, 62 | help='weight decay (default: 0.05)') 63 | 64 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 65 | help='learning rate (absolute lr)') 66 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 67 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 68 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 69 | help='lower lr bound for cyclic schedulers that hit 0') 70 | 71 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 72 | help='epochs to warmup LR') 73 | 74 | # Dataset parameters 75 | parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, 76 | help='dataset path') 77 | 78 | parser.add_argument('--output_dir', default='./output_dir', 79 | help='path where to save, empty for no saving') 80 | parser.add_argument('--log_dir', default='./output_dir', 81 | help='path where to tensorboard log') 82 | parser.add_argument('--device', default='cuda', 83 | help='device to use for training / testing') 84 | parser.add_argument('--seed', default=0, type=int) 85 | parser.add_argument('--resume', default='', 86 | help='resume from checkpoint') 87 | 88 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 89 | help='start epoch') 90 | parser.add_argument('--num_workers', default=10, type=int) 91 | parser.add_argument('--pin_mem', 
action='store_true', 92 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 93 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 94 | parser.set_defaults(pin_mem=True) 95 | 96 | # distributed training parameters 97 | parser.add_argument('--world_size', default=1, type=int, 98 | help='number of distributed processes') 99 | parser.add_argument('--local_rank', default=-1, type=int) 100 | parser.add_argument('--dist_on_itp', action='store_true') 101 | parser.add_argument('--dist_url', default='env://', 102 | help='url used to set up distributed training') 103 | 104 | return parser 105 | 106 | 107 | def main(args): 108 | misc.init_distributed_mode(args) 109 | 110 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 111 | print("{}".format(args).replace(', ', ',\n')) 112 | 113 | device = torch.device(args.device) 114 | 115 | # fix the seed for reproducibility 116 | seed = args.seed + misc.get_rank() 117 | torch.manual_seed(seed) 118 | np.random.seed(seed) 119 | 120 | cudnn.benchmark = True 121 | 122 | # simple augmentation 123 | transform_train = transforms.Compose([ 124 | transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic 125 | transforms.RandomHorizontalFlip(), 126 | transforms.ToTensor(), 127 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 128 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train) 129 | print(dataset_train) 130 | 131 | if True: # args.distributed: 132 | num_tasks = misc.get_world_size() 133 | global_rank = misc.get_rank() 134 | sampler_train = torch.utils.data.DistributedSampler( 135 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 136 | ) 137 | print("Sampler_train = %s" % str(sampler_train)) 138 | else: 139 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 140 | 141 | if global_rank == 0 and args.log_dir is not None: 142 | os.makedirs(args.log_dir, exist_ok=True) 143 | log_writer = SummaryWriter(log_dir=args.log_dir) 144 | else: 145 | log_writer = None 146 | 147 | data_loader_train = torch.utils.data.DataLoader( 148 | dataset_train, sampler=sampler_train, 149 | batch_size=args.batch_size, 150 | num_workers=args.num_workers, 151 | pin_memory=args.pin_mem, 152 | drop_last=True, 153 | ) 154 | 155 | # define the model 156 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss) 157 | 158 | model.to(device) 159 | 160 | model_without_ddp = model 161 | print("Model = %s" % str(model_without_ddp)) 162 | 163 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 164 | 165 | if args.lr is None: # only base_lr is specified 166 | args.lr = args.blr * eff_batch_size / 256 167 | 168 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 169 | print("actual lr: %.2e" % args.lr) 170 | 171 | print("accumulate grad iterations: %d" % args.accum_iter) 172 | print("effective batch size: %d" % eff_batch_size) 173 | 174 | if args.distributed: 175 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 176 | model_without_ddp = model.module 177 | 178 | # following timm: set wd as 0 for bias and norm layers 179 | param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay) 180 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 181 | print(optimizer) 182 | loss_scaler = NativeScaler() 183 | 184 | 
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 185 | 186 | print(f"Start training for {args.epochs} epochs") 187 | start_time = time.time() 188 | for epoch in range(args.start_epoch, args.epochs): 189 | if args.distributed: 190 | data_loader_train.sampler.set_epoch(epoch) 191 | train_stats = train_one_epoch( 192 | model, data_loader_train, 193 | optimizer, device, epoch, loss_scaler, 194 | log_writer=log_writer, 195 | args=args 196 | ) 197 | if args.output_dir and (epoch % 20 == 0 or epoch + 1 == args.epochs): 198 | misc.save_model( 199 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 200 | loss_scaler=loss_scaler, epoch=epoch) 201 | 202 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 203 | 'epoch': epoch,} 204 | 205 | if args.output_dir and misc.is_main_process(): 206 | if log_writer is not None: 207 | log_writer.flush() 208 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 209 | f.write(json.dumps(log_stats) + "\n") 210 | 211 | total_time = time.time() - start_time 212 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 213 | print('Training time {}'.format(total_time_str)) 214 | 215 | 216 | if __name__ == '__main__': 217 | args = get_args_parser() 218 | args = args.parse_args() 219 | if args.output_dir: 220 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 221 | main(args) 222 | -------------------------------------------------------------------------------- /models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
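# Shape reference for the MaskedAutoencoderViT defined below, assuming the default
# 224x224 input, patch size 16, and mask_ratio 0.75 (for illustration only):
#
#   num_patches                 = (224 / 16)**2         = 196
#   tokens kept by the encoder  = int(196 * (1 - 0.75)) = 49  (plus one cls token)
#   decoder output per patch    = 16 * 16 * 3           = 768 predicted pixel values
#   the loss is averaged over the 196 - 49 = 147 masked patches only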
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from timm.models.vision_transformer import PatchEmbed, Block 18 | 19 | from util.pos_embed import get_2d_sincos_pos_embed 20 | 21 | 22 | class MaskedAutoencoderViT(nn.Module): 23 | """ Masked Autoencoder with VisionTransformer backbone 24 | """ 25 | def __init__(self, img_size=224, patch_size=16, in_chans=3, 26 | embed_dim=1024, depth=24, num_heads=16, 27 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 28 | mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): 29 | super().__init__() 30 | 31 | # -------------------------------------------------------------------------- 32 | # MAE encoder specifics 33 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) 34 | num_patches = self.patch_embed.num_patches 35 | 36 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 37 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 38 | 39 | self.blocks = nn.ModuleList([ 40 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 41 | for i in range(depth)]) 42 | self.norm = norm_layer(embed_dim) 43 | # -------------------------------------------------------------------------- 44 | 45 | # -------------------------------------------------------------------------- 46 | # MAE decoder specifics 47 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 48 | 49 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 50 | 51 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 52 | 53 | self.decoder_blocks = nn.ModuleList([ 54 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 55 | for i in range(decoder_depth)]) 56 | 57 | self.decoder_norm = norm_layer(decoder_embed_dim) 58 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 59 | # -------------------------------------------------------------------------- 60 | 61 | self.norm_pix_loss = norm_pix_loss 62 | 63 | self.initialize_weights() 64 | 65 | def initialize_weights(self): 66 | # initialization 67 | # initialize (and freeze) pos_embed by sin-cos embedding 68 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 69 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 70 | 71 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 72 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 73 | 74 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 75 | w = self.patch_embed.proj.weight.data 76 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 77 | 78 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
79 | torch.nn.init.normal_(self.cls_token, std=.02) 80 | torch.nn.init.normal_(self.mask_token, std=.02) 81 | 82 | # initialize nn.Linear and nn.LayerNorm 83 | self.apply(self._init_weights) 84 | 85 | def _init_weights(self, m): 86 | if isinstance(m, nn.Linear): 87 | # we use xavier_uniform following official JAX ViT: 88 | torch.nn.init.xavier_uniform_(m.weight) 89 | if isinstance(m, nn.Linear) and m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, nn.LayerNorm): 92 | nn.init.constant_(m.bias, 0) 93 | nn.init.constant_(m.weight, 1.0) 94 | 95 | def patchify(self, imgs): 96 | """ 97 | imgs: (N, 3, H, W) 98 | x: (N, L, patch_size**2 *3) 99 | """ 100 | p = self.patch_embed.patch_size[0] 101 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 102 | 103 | h = w = imgs.shape[2] // p 104 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 105 | x = torch.einsum('nchpwq->nhwpqc', x) 106 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 107 | return x 108 | 109 | def unpatchify(self, x): 110 | """ 111 | x: (N, L, patch_size**2 *3) 112 | imgs: (N, 3, H, W) 113 | """ 114 | p = self.patch_embed.patch_size[0] 115 | h = w = int(x.shape[1]**.5) 116 | assert h * w == x.shape[1] 117 | 118 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 119 | x = torch.einsum('nhwpqc->nchpwq', x) 120 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 121 | return imgs 122 | 123 | def random_masking(self, x, mask_ratio): 124 | """ 125 | Perform per-sample random masking by per-sample shuffling. 126 | Per-sample shuffling is done by argsort random noise. 127 | x: [N, L, D], sequence 128 | """ 129 | N, L, D = x.shape # batch, length, dim 130 | len_keep = int(L * (1 - mask_ratio)) 131 | 132 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 133 | 134 | # sort noise for each sample 135 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 136 | ids_restore = torch.argsort(ids_shuffle, dim=1) 137 | 138 | # keep the first subset 139 | ids_keep = ids_shuffle[:, :len_keep] 140 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 141 | 142 | # generate the binary mask: 0 is keep, 1 is remove 143 | mask = torch.ones([N, L], device=x.device) 144 | mask[:, :len_keep] = 0 145 | # unshuffle to get the binary mask 146 | mask = torch.gather(mask, dim=1, index=ids_restore) 147 | 148 | return x_masked, mask, ids_restore 149 | 150 | def forward_encoder(self, x, mask_ratio): 151 | # embed patches 152 | x = self.patch_embed(x) 153 | 154 | # add pos embed w/o cls token 155 | x = x + self.pos_embed[:, 1:, :] 156 | 157 | # masking: length -> length * mask_ratio 158 | x, mask, ids_restore = self.random_masking(x, mask_ratio) 159 | 160 | # append cls token 161 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 162 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 163 | x = torch.cat((cls_tokens, x), dim=1) 164 | 165 | # apply Transformer blocks 166 | for blk in self.blocks: 167 | x = blk(x) 168 | x = self.norm(x) 169 | 170 | return x, mask, ids_restore 171 | 172 | def forward_decoder(self, x, ids_restore): 173 | # embed tokens 174 | x = self.decoder_embed(x) 175 | 176 | # append mask tokens to sequence 177 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 178 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 179 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 180 | x = torch.cat([x[:, :1, :], x_], dim=1) 
# append cls token 181 | 182 | # add pos embed 183 | x = x + self.decoder_pos_embed 184 | 185 | # apply Transformer blocks 186 | for blk in self.decoder_blocks: 187 | x = blk(x) 188 | x = self.decoder_norm(x) 189 | 190 | # predictor projection 191 | x = self.decoder_pred(x) 192 | 193 | # remove cls token 194 | x = x[:, 1:, :] 195 | 196 | return x 197 | 198 | def forward_loss(self, imgs, pred, mask): 199 | """ 200 | imgs: [N, 3, H, W] 201 | pred: [N, L, p*p*3] 202 | mask: [N, L], 0 is keep, 1 is remove, 203 | """ 204 | target = self.patchify(imgs) 205 | if self.norm_pix_loss: 206 | mean = target.mean(dim=-1, keepdim=True) 207 | var = target.var(dim=-1, keepdim=True) 208 | target = (target - mean) / (var + 1.e-6)**.5 209 | 210 | loss = (pred - target) ** 2 211 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 212 | 213 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 214 | return loss 215 | 216 | def forward(self, imgs, mask_ratio=0.75): 217 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 218 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 219 | loss = self.forward_loss(imgs, pred, mask) 220 | return loss, pred, mask 221 | 222 | 223 | def mae_vit_base_patch16_dec512d8b(**kwargs): 224 | model = MaskedAutoencoderViT( 225 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 226 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 227 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 228 | return model 229 | 230 | 231 | def mae_vit_large_patch16_dec512d8b(**kwargs): 232 | model = MaskedAutoencoderViT( 233 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 234 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 235 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 236 | return model 237 | 238 | 239 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 240 | model = MaskedAutoencoderViT( 241 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 242 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 243 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 244 | return model 245 | 246 | 247 | # set recommended archs 248 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 249 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 250 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 251 | -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
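# Minimal usage sketch for the MAE factories defined in models_mae.py above
# (illustrative only; the dummy batch and the mask ratio are arbitrary choices):
#
#   import torch
#   import models_mae
#
#   model = models_mae.mae_vit_base_patch16(norm_pix_loss=True)
#   imgs = torch.randn(2, 3, 224, 224)                 # dummy images
#   loss, pred, mask = model(imgs, mask_ratio=0.75)    # pred: [2, 196, 768], mask: [2, 196]
#   recon = model.unpatchify(pred)                     # back to [2, 3, 224, 224]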
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | if self.global_pool: 47 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 48 | outcome = self.fc_norm(x) 49 | else: 50 | x = self.norm(x) 51 | outcome = x[:, 0] 52 | 53 | return outcome 54 | 55 | 56 | def vit_base_patch16(**kwargs): 57 | model = VisionTransformer( 58 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 59 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 60 | return model 61 | 62 | 63 | def vit_large_patch16(**kwargs): 64 | model = VisionTransformer( 65 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 66 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 67 | return model 68 | 69 | 70 | def vit_huge_patch14(**kwargs): 71 | model = VisionTransformer( 72 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 73 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 74 | return model -------------------------------------------------------------------------------- /submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_finetune as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. 
Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_finetune as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /submitit_linprobe.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_linprobe as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_linprobe as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_pretrain as trainer 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | trainer_parser = trainer.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_pretrain as trainer 57 | 58 | self._setup_gpu_args() 59 | trainer.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | 
tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, # max is 60 * 72 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
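# Worked example of the eval-transform sizing in build_transform below (illustration;
# the 384 input is a hypothetical larger size):
#
#   input_size 224: crop_pct = 224/256 = 0.875 -> Resize(int(224 / 0.875) = 256), CenterCrop(224)
#   input_size 384: crop_pct = 1.0             -> Resize(384),                    CenterCrop(384)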
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 
27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | '{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | 
time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 246 | world_size=args.world_size, rank=args.rank) 247 | torch.distributed.barrier() 248 | setup_for_distributed(args.rank == 0) 249 | 250 | 251 | class NativeScalerWithGradNormCount: 252 | state_dict_key = "amp_scaler" 253 | 254 | def __init__(self): 255 | self._scaler = torch.cuda.amp.GradScaler() 256 | 257 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 258 | 
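        # Summary of the AMP step below: backpropagate on the scaled loss; when
        # update_grad, unscale the gradients so they can be clipped or measured
        # in their true range, let GradScaler.step() skip the optimizer update
        # if inf/NaN gradients are found, then update() the scale factor for the
        # next iteration. The returned norm is the pre-clip gradient norm, or
        # None when gradients are only being accumulated (update_grad=False).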
self._scaler.scale(loss).backward(create_graph=create_graph) 259 | if update_grad: 260 | if clip_grad is not None: 261 | assert parameters is not None 262 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 263 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 264 | else: 265 | self._scaler.unscale_(optimizer) 266 | norm = get_grad_norm_(parameters) 267 | self._scaler.step(optimizer) 268 | self._scaler.update() 269 | else: 270 | norm = None 271 | return norm 272 | 273 | def state_dict(self): 274 | return self._scaler.state_dict() 275 | 276 | def load_state_dict(self, state_dict): 277 | self._scaler.load_state_dict(state_dict) 278 | 279 | 280 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 281 | if isinstance(parameters, torch.Tensor): 282 | parameters = [parameters] 283 | parameters = [p for p in parameters if p.grad is not None] 284 | norm_type = float(norm_type) 285 | if len(parameters) == 0: 286 | return torch.tensor(0.) 287 | device = parameters[0].grad.device 288 | if norm_type == inf: 289 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 290 | else: 291 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 292 | return total_norm 293 | 294 | 295 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 296 | output_dir = Path(args.output_dir) 297 | epoch_name = str(epoch) 298 | if loss_scaler is not None: 299 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 300 | for checkpoint_path in checkpoint_paths: 301 | to_save = { 302 | 'model': model_without_ddp.state_dict(), 303 | 'optimizer': optimizer.state_dict(), 304 | 'epoch': epoch, 305 | 'scaler': loss_scaler.state_dict(), 306 | 'args': args, 307 | } 308 | 309 | save_on_master(to_save, checkpoint_path) 310 | else: 311 | client_state = {'epoch': epoch} 312 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) 313 | 314 | 315 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 316 | if args.resume: 317 | if args.resume.startswith('https'): 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | args.resume, map_location='cpu', check_hash=True) 320 | else: 321 | checkpoint = torch.load(args.resume, map_location='cpu') 322 | model_without_ddp.load_state_dict(checkpoint['model']) 323 | print("Resume checkpoint %s" % args.resume) 324 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 325 | optimizer.load_state_dict(checkpoint['optimizer']) 326 | args.start_epoch = checkpoint['epoch'] + 1 327 | if 'scaler' in checkpoint: 328 | loss_scaler.load_state_dict(checkpoint['scaler']) 329 | print("With optim & sched!") 330 | 331 | 332 | def all_reduce_mean(x): 333 | world_size = get_world_size() 334 | if world_size > 1: 335 | x_reduce = torch.tensor(x).cuda() 336 | dist.all_reduce(x_reduce) 337 | x_reduce /= world_size 338 | return x_reduce.item() 339 | else: 340 | return x -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
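# Note (illustrative, not from the original file): embed_dim must be divisible
# by 4 in the 2-D case -- half of the channels encode one grid axis and half
# the other, and each half is further split into sin and cos components.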
6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | --------------------------------------------------------------------------------
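
The position-embedding utilities above are easy to exercise in isolation. Below is a minimal sketch, not part of the repository, showing how the fixed sin-cos table and interpolate_pos_embed fit together when a checkpoint pre-trained at 224x224 is loaded for fine-tuning at 448x448. It assumes it is run from the repository root so that `util/` is importable; the 768-dim / 16-pixel-patch numbers are illustrative ViT-Base values, and the SimpleNamespace stand-in only mimics the two attributes interpolate_pos_embed reads. Note that get_1d_sincos_pos_embed_from_grid uses the `np.float` alias, which NumPy 1.24 removed, so running this may require changing it to `float` or pinning an older NumPy.

import torch
from types import SimpleNamespace

from util.pos_embed import get_2d_sincos_pos_embed, interpolate_pos_embed

# 1) Fixed sin-cos table for a 14x14 patch grid (224px inputs, 16px patches),
#    with one extra all-zero row prepended for the cls token.
pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)  # (197, 768)

# 2) Stand-in "model" exposing only what interpolate_pos_embed reads: the target
#    number of patches (28x28 for 448px inputs) and a pos_embed tensor whose
#    length implies one extra (cls) token.
model = SimpleNamespace(
    patch_embed=SimpleNamespace(num_patches=28 * 28),
    pos_embed=torch.zeros(1, 1 + 28 * 28, 768),
)

# 3) Pretend checkpoint produced at the 14x14 resolution; interpolate_pos_embed
#    bicubically resamples the 196 position tokens to 784 and keeps the cls token.
checkpoint = {'pos_embed': torch.from_numpy(pos).float().unsqueeze(0)}  # (1, 197, 768)
interpolate_pos_embed(model, checkpoint)
print(checkpoint['pos_embed'].shape)  # torch.Size([1, 785, 768])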