├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DATASET.md ├── FINETUNE.md ├── LICENSE ├── PRETRAIN.md ├── README.md ├── demo ├── demo.avi ├── mr-95-demo-vid-0.gif ├── mr-95-demo-vid-1.gif ├── mr-95-demo-vid-2.gif ├── mr-95-demo-vid-3.gif ├── mr-95-demo-vid-4.gif ├── mr-95-demo-vid.gif ├── mr-98-demo-vid-0.gif ├── mr-98-demo-vid-1.gif ├── mr-98-demo-vid-2.gif ├── mr-98-demo-vid-3.gif ├── mr-98-demo-vid-4.gif └── mr-98-demo-vid.gif ├── engine_finetune.py ├── engine_pretrain.py ├── engine_test.py ├── launch_fb_flow.sh ├── launch_fb_local.sh ├── launch_fb_tensorboard.sh ├── main_finetune.py ├── main_pretrain.py ├── main_test.py ├── models_mae.py ├── models_vit.py ├── run_finetune.py ├── run_pretrain.py ├── run_test.py └── util ├── decoder ├── decoder.py ├── mixup.py ├── rand_augment.py ├── random_erasing.py ├── transform.py ├── utils.py └── video_container.py ├── env.py ├── kinetics.py ├── logging.py ├── lr_decay.py ├── lr_sched.py ├── meters.py ├── misc.py ├── pos_embed.py └── video_vit.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mae_st 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to mae, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | ## Kinetics 4 | 5 | The Kinetics Dataset could be downloaded via the code released by ActivityNet: 6 | 7 | 1. Download the videos via the official [scripts](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics). 8 | 9 | 2. After all the videos were downloaded, resize the video to the short edge size of 256, then prepare the csv files for training, validation, and testing set as `train.csv`, `val.csv`, `test.csv`. The format of the csv file is: 10 | 11 | ``` 12 | path_to_video_1 label_1 13 | path_to_video_2 label_2 14 | path_to_video_3 label_3 15 | ... 16 | path_to_video_N label_N 17 | ``` 18 | 19 | More info about video dataset preparation can be found in [PySlowFast](https://github.com/facebookresearch/SlowFast). For dataset specific issues, please reach out to the [dataset provider](https://deepmind.com/research/open-source/kinetics). 20 | -------------------------------------------------------------------------------- /FINETUNE.md: -------------------------------------------------------------------------------- 1 | ## Fine-tuning Pre-trained MAE for Classification 2 | 3 | ### Evaluation 4 | 5 | As a sanity check, run evaluation using our ImageNet **fine-tuned** models: 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 |
ViT-LargeViT-Huge
pre-trained checkpoint on Kinetics-400downloaddownload
md5edf3a53d7f64
26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
ViT-LargeViT-Huge
pre-trained checkpoint on Kinetics-600downloaddownload
md59a964527495e
48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 |
ViT-LargeViT-Huge
pre-trained checkpoint on Kinetics-700downloaddownload
md5cdbada4c4e3c
69 | 70 | 71 | Evaluate ViT-Large: (`${KINETICS_DIR}` is a directory containing `{train, val}` sets of Kinetics): 72 | ``` 73 | python run_finetune.py --path_to_data_dir ${KINETICS_DIR} --rand_aug --epochs 50 --repeat_aug 2 --model vit_large_patch16 --batch_size 2 --distributed --dist_eval --smoothing 0.1 --mixup 0.8 --cutmix 1.0 --mixup_prob 1.0 --blr 0.0024 --num_frames 16 --sampling_rate 4 --dropout 0.3 --warmup_epochs 5 --layer_decay 0.75 --drop_path_rate 0.2 --aa rand-m7-mstd0.5-inc1 --clip_grad 5.0 --fp32"}${FINETUNE_APPENDIX} 74 | ``` 75 | This should give: 76 | ``` 77 | * Acc@1 84.35 78 | ``` 79 | 80 | #### Notes 81 | 82 | - The pre-trained models we provide are trained with *normalized* pixels `--norm_pix_loss` (1600 effective epochs). The models are pretrained in PySlowFast codebase. 83 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 401 | -------------------------------------------------------------------------------- /PRETRAIN.md: -------------------------------------------------------------------------------- 1 | ## Pre-training MAE 2 | To pre-train ViT-Large (recommended default), run the following: 3 | 4 | ``` 5 | python run_pretrain.py \ 6 | --path_to_data_dir ${KINETICS_DIR} \ 7 | --batch_size 2 \ 8 | --model mae_vit_large_patch16 \ 9 | --no_env \ 10 | --epochs 100 \ 11 | --distributed \ 12 | --num_frames 16 \ 13 | --decoder_embed_dim 512 \ 14 | --decoder_depth 4 \ 15 | --pin_mem \ 16 | --num_workers 14 \ 17 | --t_patch_size 2 \ 18 | --repeat_aug 4 \ 19 | --sampling_rate 4 \ 20 | --norm_pix_loss \ 21 | --blr 1.6e-3 \ 22 | --warmup_epochs 5 \ 23 | --mask_ratio 0.9 \ 24 | --pred_t_dim 8 \ 25 | --clip_grad 0.02 \ 26 | ``` 27 | 28 | blr is the base learning rate. The actual lr is computed by the linear scaling rule: lr = blr * effective batch size / 256. 29 | Here we use --norm_pix_loss as the target for better representation learning. To train a baseline model (e.g., for visualization), use pixel-based constructiomae_vit_large_patch16n and turn off --norm_pix_loss. 30 | To train ViT-Base or ViT-Huge, set --model mae_vit_base_patch16 or --model mae_vit_huge_patch14. 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Masked Autoencoders As Spatiotemporal Learners: A PyTorch Implementation 2 | 3 |

4 | 5 |

6 | 7 | 8 | This is a PyTorch/GPU re-implementation of the paper [Masked Autoencoders As Spatiotemporal Learners](https://arxiv.org/abs/2205.09113): 9 | ``` 10 | @Article{MaskedAutoencodersSpatiotemporal2022, 11 | author = {Christoph Feichtenhofer and Haoqi Fan and Yanghao Li and Kaiming He}, 12 | journal = {arXiv:2205.09113}, 13 | title = {Masked Autoencoders As Spatiotemporal Learners}, 14 | year = {2022}, 15 | } 16 | ``` 17 | Another implementation that supports AVA and SSv2 downstream evaluation is available in [PySlowFast](https://github.com/facebookresearch/SlowFast). 18 | 19 | * This repo is a modification on the [MAE repo](https://github.com/facebookresearch/mae). Installation and preparation follow [INSTALL.md](https://github.com/facebookresearch/SlowFast/blob/main/INSTALL.md). 20 | 21 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. 22 | 23 | 24 | 25 | ### Catalog 26 | 27 | - [x] Visualization demo 28 | - [x] Pre-trained checkpoints + fine-tuning code + testing code 29 | - [x] Pre-training code 30 | 31 | ### Visualization demo 32 | 33 | 34 | Visualization of MAE output with 95% (left) and 98% (right) mask rate on the same video. 35 |
36 | 37 |
38 |
39 | 40 |
41 |
42 | 43 |
44 |
45 | 46 |
47 |
48 | 49 |
50 | 51 | Run our interactive visualization demo using [Colab notebook](TO_BE_ADD) (no GPU needed): 52 |

53 | 54 |

55 | 56 | ### Fine-tuning with pre-trained checkpoints 57 | 58 | The following table provides the pre-trained checkpoints used in the paper, pretrained with **90%** mask ratio and **1600 effective epochs**, converted from the PySlowFast codebase: 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 |
ViT-LargeViT-Huge
pre-trained checkpoint on Kinetics-400downloaddownload
md5edf3a53d7f64
77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 |
ViT-LargeViT-Huge
pre-trained checkpoint on Kinetics-600downloaddownload
md59a964527495e
96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 |
ViT-LargeViT-Huge
pre-trained checkpoint on Kinetics-700downloaddownload
md5cdbada4c4e3c
115 | 116 | 117 | The fine-tuning instruction is in [FINETUNE.md](FINETUNE.md). 118 | 119 | 120 | ### Pre-training 121 | 122 | The pre-training instruction is in [PRETRAIN.md](PRETRAIN.md). 123 | 124 | ### License 125 | 126 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 127 | -------------------------------------------------------------------------------- /demo/demo.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/demo.avi -------------------------------------------------------------------------------- /demo/mr-95-demo-vid-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-95-demo-vid-0.gif -------------------------------------------------------------------------------- /demo/mr-95-demo-vid-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-95-demo-vid-1.gif -------------------------------------------------------------------------------- /demo/mr-95-demo-vid-2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-95-demo-vid-2.gif -------------------------------------------------------------------------------- /demo/mr-95-demo-vid-3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-95-demo-vid-3.gif -------------------------------------------------------------------------------- /demo/mr-95-demo-vid-4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-95-demo-vid-4.gif -------------------------------------------------------------------------------- /demo/mr-95-demo-vid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-95-demo-vid.gif -------------------------------------------------------------------------------- /demo/mr-98-demo-vid-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-98-demo-vid-0.gif -------------------------------------------------------------------------------- /demo/mr-98-demo-vid-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-98-demo-vid-1.gif -------------------------------------------------------------------------------- /demo/mr-98-demo-vid-2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-98-demo-vid-2.gif -------------------------------------------------------------------------------- /demo/mr-98-demo-vid-3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-98-demo-vid-3.gif -------------------------------------------------------------------------------- /demo/mr-98-demo-vid-4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-98-demo-vid-4.gif -------------------------------------------------------------------------------- /demo/mr-98-demo-vid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mae_st/c5dec1bc01062097906ea67fe47935ec33fd46df/demo/mr-98-demo-vid.gif -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import mae_st.util.lr_sched as lr_sched 17 | import mae_st.util.misc as misc 18 | import torch 19 | from mae_st.util.logging import master_print as print 20 | from timm.data import Mixup 21 | from timm.utils import accuracy 22 | 23 | 24 | def train_one_epoch( 25 | model: torch.nn.Module, 26 | criterion: torch.nn.Module, 27 | data_loader: Iterable, 28 | optimizer: torch.optim.Optimizer, 29 | device: torch.device, 30 | epoch: int, 31 | loss_scaler, 32 | max_norm: float = 0, 33 | mixup_fn: Optional[Mixup] = None, 34 | log_writer=None, 35 | args=None, 36 | fp32=False, 37 | ): 38 | model.train(True) 39 | metric_logger = misc.MetricLogger(delimiter=" ") 40 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) 41 | metric_logger.add_meter( 42 | "cpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 43 | ) 44 | metric_logger.add_meter( 45 | "cpu_mem_all", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 46 | ) 47 | metric_logger.add_meter( 48 | "gpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 49 | ) 50 | header = "Epoch: [{}]".format(epoch) 51 | print_freq = 20 52 | 53 | accum_iter = args.accum_iter 54 | 55 | optimizer.zero_grad() 56 | 57 | if log_writer is not None: 58 | print("log_dir: {}".format(log_writer.log_dir)) 59 | 60 | for data_iter_step, (samples, targets) in enumerate( 61 | metric_logger.log_every(data_loader, print_freq, header) 62 | ): 63 | # we use a per iteration (instead of per epoch) lr scheduler 64 | if data_iter_step % accum_iter == 0: 65 | lr_sched.adjust_learning_rate( 66 | optimizer, data_iter_step / len(data_loader) + epoch, args 67 | ) 68 | 69 | if len(samples.shape) == 6: 70 | b, r, c, t, h, w = samples.shape 71 | samples = samples.view(b * r, c, t, h, w) 72 | targets = targets.view(b * r) 73 | 74 | if args.cpu_mix: 75 | if mixup_fn is not None: 76 | samples, targets = mixup_fn(samples, targets) 77 | samples = samples.to(device, non_blocking=True) 78 | targets = targets.to(device, non_blocking=True) 79 | else: 80 | samples = samples.to(device, non_blocking=True) 81 | targets = targets.to(device, non_blocking=True) 82 | if mixup_fn is not None: 83 | samples, targets = mixup_fn(samples, targets) 84 | 85 | with torch.cuda.amp.autocast(enabled=not fp32): 86 | outputs = model(samples) 87 | loss = criterion(outputs, targets) 88 | 89 | loss_value = loss.item() 90 | 91 | if not math.isfinite(loss_value): 92 | print("Loss is {}, stopping training".format(loss_value)) 93 | sys.exit(1) 94 | 95 | loss /= accum_iter 96 | loss_scaler( 97 | loss, 98 | optimizer, 99 | clip_grad=max_norm, 100 | parameters=model.parameters(), 101 | create_graph=False, 102 | update_grad=(data_iter_step + 1) % accum_iter == 0, 103 | ) 104 | if (data_iter_step + 1) % accum_iter == 0: 105 | optimizer.zero_grad() 106 | 107 | torch.cuda.synchronize() 108 | 109 | metric_logger.update(loss=loss_value) 110 | metric_logger.update(cpu_mem=misc.cpu_mem_usage()[0]) 111 | metric_logger.update(cpu_mem_all=misc.cpu_mem_usage()[1]) 112 | metric_logger.update(gpu_mem=misc.gpu_mem_usage()) 113 | min_lr = 10.0 114 | max_lr = 0.0 115 | for group in optimizer.param_groups: 116 | min_lr = min(min_lr, group["lr"]) 117 | max_lr = max(max_lr, group["lr"]) 118 | 119 | metric_logger.update(lr=max_lr) 120 | 121 | loss_value_reduce = misc.all_reduce_mean(loss_value) 122 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 123 | """We use epoch_1000x as the x-axis in tensorboard. 124 | This calibrates different curves when batch size changes. 125 | """ 126 | epoch_1000x = int( 127 | (data_iter_step / len(data_loader) + epoch) * 1000 * args.repeat_aug 128 | ) 129 | log_writer.add_scalar("loss", loss_value_reduce, epoch_1000x) 130 | log_writer.add_scalar("lr", max_lr, epoch_1000x) 131 | 132 | # gather the stats from all processes 133 | metric_logger.synchronize_between_processes() 134 | print("Averaged stats:", metric_logger) 135 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 136 | 137 | 138 | @torch.no_grad() 139 | def evaluate(data_loader, model, device): 140 | criterion = torch.nn.CrossEntropyLoss() 141 | 142 | metric_logger = misc.MetricLogger(delimiter=" ") 143 | header = "Test:" 144 | 145 | # switch to evaluation mode 146 | model.eval() 147 | 148 | for batch in metric_logger.log_every(data_loader, 10, header): 149 | images = batch[0] 150 | target = batch[-1] 151 | images = images.to(device, non_blocking=True) 152 | target = target.to(device, non_blocking=True) 153 | 154 | if len(images.shape) == 6: 155 | b, r, c, t, h, w = images.shape 156 | images = images.view(b * r, c, t, h, w) 157 | target = target.view(b * r) 158 | 159 | # compute output 160 | with torch.cuda.amp.autocast(): 161 | output = model(images) 162 | loss = criterion(output, target) 163 | 164 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 165 | 166 | batch_size = images.shape[0] 167 | metric_logger.update(loss=loss.item()) 168 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 169 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 170 | # gather the stats from all processes 171 | metric_logger.synchronize_between_processes() 172 | print( 173 | "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format( 174 | top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss 175 | ) 176 | ) 177 | 178 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 179 | -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | from typing import Iterable 13 | 14 | import mae_st.util.lr_sched as lr_sched 15 | import mae_st.util.misc as misc 16 | import torch 17 | from iopath.common.file_io import g_pathmgr as pathmgr 18 | 19 | 20 | def train_one_epoch( 21 | model: torch.nn.Module, 22 | data_loader: Iterable, 23 | optimizer: torch.optim.Optimizer, 24 | device: torch.device, 25 | epoch: int, 26 | loss_scaler, 27 | log_writer=None, 28 | args=None, 29 | fp32=False, 30 | ): 31 | model.train(True) 32 | metric_logger = misc.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) 34 | metric_logger.add_meter( 35 | "cpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 36 | ) 37 | metric_logger.add_meter( 38 | "cpu_mem_all", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 39 | ) 40 | metric_logger.add_meter( 41 | "gpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 42 | ) 43 | metric_logger.add_meter( 44 | "mask_ratio", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 45 | ) 46 | header = "Epoch: [{}]".format(epoch) 47 | print_freq = 20 48 | 49 | accum_iter = args.accum_iter 50 | 51 | optimizer.zero_grad() 52 | 53 | if log_writer is not None: 54 | print("log_dir: {}".format(log_writer.log_dir)) 55 | 56 | for data_iter_step, (samples, _) in enumerate( 57 | metric_logger.log_every(data_loader, print_freq, header) 58 | ): 59 | # we use a per iteration (instead of per epoch) lr scheduler 60 | if data_iter_step % accum_iter == 0: 61 | lr_sched.adjust_learning_rate( 62 | optimizer, data_iter_step / len(data_loader) + epoch, args 63 | ) 64 | 65 | samples = samples.to(device, non_blocking=True) 66 | if len(samples.shape) == 6: 67 | b, r, c, t, h, w = samples.shape 68 | samples = samples.reshape(b * r, c, t, h, w) 69 | 70 | with torch.cuda.amp.autocast(enabled=not fp32): 71 | loss, _, _ = model( 72 | samples, 73 | mask_ratio=args.mask_ratio, 74 | ) 75 | 76 | loss_value = loss.item() 77 | 78 | if not math.isfinite(loss_value): 79 | for _ in range(args.num_checkpoint_del): 80 | try: 81 | path = misc.get_last_checkpoint(args) 82 | pathmgr.rm(path) 83 | print(f"remove checkpoint {path}") 84 | except Exception as _: 85 | pass 86 | raise Exception("Loss is {}, stopping training".format(loss_value)) 87 | 88 | loss /= accum_iter 89 | loss_scaler( 90 | loss, 91 | optimizer, 92 | parameters=model.parameters(), 93 | update_grad=(data_iter_step + 1) % accum_iter == 0, 94 | clip_grad=args.clip_grad, 95 | ) 96 | 97 | if (data_iter_step + 1) % accum_iter == 0: 98 | optimizer.zero_grad() 99 | 100 | torch.cuda.synchronize() 101 | 102 | metric_logger.update(loss=loss_value) 103 | metric_logger.update(cpu_mem=misc.cpu_mem_usage()[0]) 104 | metric_logger.update(cpu_mem_all=misc.cpu_mem_usage()[1]) 105 | metric_logger.update(gpu_mem=misc.gpu_mem_usage()) 106 | metric_logger.update(mask_ratio=args.mask_ratio) 107 | 108 | lr = optimizer.param_groups[0]["lr"] 109 | metric_logger.update(lr=lr) 110 | 111 | loss_value_reduce = misc.all_reduce_mean(loss_value) 112 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 113 | """We use epoch_1000x as the x-axis in tensorboard. 114 | This calibrates different curves when batch size changes. 115 | """ 116 | epoch_1000x = int( 117 | (data_iter_step / len(data_loader) + epoch) * 1000 * args.repeat_aug 118 | ) 119 | log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x) 120 | log_writer.add_scalar("lr", lr, epoch_1000x) 121 | 122 | # gather the stats from all processes 123 | metric_logger.synchronize_between_processes() 124 | print("Averaged stats:", metric_logger) 125 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 126 | -------------------------------------------------------------------------------- /engine_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import mae_st.util.misc as misc 13 | import torch 14 | 15 | 16 | @torch.no_grad() 17 | def test(data_loader, model, device, test_meter, fp32=False): 18 | metric_logger = misc.MetricLogger(delimiter=" ") 19 | 20 | # switch to evaluation mode 21 | model.eval() 22 | softmax = torch.nn.Softmax(dim=1).cuda() 23 | 24 | for cur_iter, (images, labels, video_idx) in enumerate(data_loader): 25 | images = images.to(device, non_blocking=True) 26 | labels = labels.to(device, non_blocking=True) 27 | video_idx = video_idx.to(device, non_blocking=True) 28 | 29 | if len(images.shape) == 6: 30 | b, r, c, t, h, w = images.shape 31 | images = images.view(b * r, c, t, h, w) 32 | labels = labels.view(b * r) 33 | 34 | # compute output 35 | with torch.cuda.amp.autocast(enabled=not fp32): 36 | preds = model(images) 37 | preds = softmax(preds) 38 | 39 | if torch.distributed.is_initialized(): 40 | preds, labels, video_idx = misc.all_gather([preds, labels, video_idx]) 41 | preds = preds.cpu() 42 | labels = labels.cpu() 43 | video_idx = video_idx.cpu() 44 | # Update and log stats. 45 | test_meter.update_stats(preds.detach(), labels.detach(), video_idx.detach()) 46 | test_meter.log_iter_stats(cur_iter) 47 | 48 | test_meter.finalize_metrics() 49 | # gather the stats from all processes 50 | metric_logger.synchronize_between_processes() 51 | return test_meter.stats 52 | -------------------------------------------------------------------------------- /launch_fb_flow.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | #!/bin/bash 5 | 6 | if [ "$#" -lt 1 ] 7 | then 8 | echo "Need at least 1 parameter to determine number of machines" 9 | exit 10 | fi 11 | 12 | CHECKPOINT_PATH=manifold://winvision/tree/${USER}/logs/$(date -d "${start} + 1 day" +%F-%H%M%S-%3N) 13 | echo "${CHECKPOINT_PATH}" 14 | 15 | manifold mkdirs "${CHECKPOINT_PATH#"manifold://"}" 16 | manifold mkdirs "${CHECKPOINT_PATH#"manifold://"}""/pretrain" 17 | manifold mkdirs "${CHECKPOINT_PATH#"manifold://"}""/downstream" 18 | 19 | GANG_SCHEDULE=${GANG_SCHEDULE-1} 20 | GANG_AFFINITY=${GANG_AFFINITY-0} 21 | GPU_TYPE=${GPU_TYPE-3} 22 | POSTFIX=${POSTFIX-"benchmark"} 23 | ENT=${ENT-"default_ncg"} 24 | RUN_BENCHMARK=${RUN_BENCHMARK-0} 25 | 26 | if [ "$1" -lt 1 ] 27 | then 28 | FINETUNE_APPENDIX=" --finetune "${4} 29 | else 30 | FINETUNE_APPENDIX="" 31 | fi 32 | if [ "$2" -lt 1 ] 33 | then 34 | TESTING_APPENDIX=" --finetune "${4} 35 | else 36 | TESTING_APPENDIX="" 37 | fi 38 | 39 | P_CONFIG=${P_CONFIG-"--path_to_data_dir manifold://fair_vision_data/tree/PySlowFast/kinetics/k400 --batch_size 2 --model mae_vit_large_patch16 --no_env --epochs 100 --distributed --num_frames 16 --decoder_embed_dim 512 --decoder_depth 4 --repeat_aug 4 --sampling_rate 4 --norm_pix_loss --blr 1.6e-3 --warmup_epochs 20 --mask_ratio 0.9 --cls_embed --pred_t_dim 8 --fp32 --sep_pos_embed --clip_grad 0.02"} 40 | 41 | D_CONFIG=${D_CONFIG-"--path_to_data_dir manifold://fair_vision_data/tree/PySlowFast/kinetics/k400 --rand_aug --epochs 50 --no_env --repeat_aug 2 --model vit_large_patch16 --batch_size 2 --distributed --dist_eval --smoothing 0.1 --mixup 0.8 --cutmix 1.0 --mixup_prob 1.0 --blr 0.0024 --num_frames 16 --pin_mem --num_workers 12 --sampling_rate 4 --dropout 0.3 --warmup_epochs 5 --layer_decay 0.75 --drop_path_rate 0.2 --aa rand-m7-mstd0.5-inc1 --cls_embed --sep_pos_embed --clip_grad 5.0 --fp32"}${FINETUNE_APPENDIX} 42 | 43 | T_CONFIG=${T_CONFIG-"--path_to_data_dir manifold://fair_vision_data/tree/PySlowFast/kinetics/k400 --no_env --model vit_large_patch16 --batch_size 2 --distributed --dist_eval --num_frames 16 --pin_mem --num_workers 12 --sampling_rate 4 --dropout 0.3 --cls_embed --sep_pos_embed --fp32"}${TESTING_APPENDIX} 44 | 45 | 46 | flow-cli canary mae_st.mae_st.workflow@//fblearner/flow/projects/mae_st:workflow \ 47 | --parameters-json '{ 48 | "num_shard_pretrain": '"${1}"', 49 | "num_shard_finetune": '"${2}"', 50 | "num_shard_test": '"${3}"', 51 | "pretrain_config": "'"${P_CONFIG}"'", 52 | "downstream_config": "'"${D_CONFIG}"'", 53 | "test_config": "'"${T_CONFIG}"'", 54 | "output_dir": "'"${CHECKPOINT_PATH}"'", 55 | "gang_schedule": "'"${GANG_SCHEDULE}"'", 56 | "gang_affinity": "'"${GANG_AFFINITY}"'", 57 | "gpu_type": "'"${GPU_TYPE}"'", 58 | "entitlement": "'"${ENT}"'"}' \ 59 | --entitlement "default_ncg" \ 60 | --run-as-secure-group "${SECURE_GROUP-vidron}" \ 61 | --name "${POSTFIX}||${P_CONFIG}||${1}nodes" \ 62 | --mode opt \ 63 | -------------------------------------------------------------------------------- /launch_fb_local.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | #!/usr/bin/env bash 5 | 6 | # sudo fuser -v /dev/nvidia* | grep -o '[[:digit:]]*' |xargs -I{} sudo kill -9 {} 7 | 8 | # buck build --config client.skip-action-graph-cache=true @mode/opt -c python.native_link_strategy=separate \ 9 | buck build @mode/opt @mode/inplace \ 10 | //vision/fair/mae_st/... --show-output 11 | 12 | # 0: pretrain, 1: finetune, 2: test 13 | 14 | if [ "$1" -lt 1 ] 15 | then 16 | 17 | echo "pretrain" 18 | 19 | /data/users/"${USER}"/fbsource/fbcode/buck-out/gen/vision/fair/mae_st/run_pretrain_bin.par \ 20 | --path_to_data_dir manifold://fair_vision_data/tree/PySlowFast/kinetics/k400 \ 21 | --batch_size 1 --decoder_embed_dim 64 --decoder_depth 2 \ 22 | --epochs 1 --mask_ratio 0.9 --repeat_aug 2 \ 23 | --model mae_vit_large_patch16 \ 24 | --sampling_rate 4 --num_frames 16 \ 25 | --num_workers 2 \ 26 | --bias_wd \ 27 | --trunc_init \ 28 | --fp32 \ 29 | --jitter_aspect_relative 0.75 1.3333 --jitter_scales_relative 0.5 1.0 \ 30 | --cls_embed \ 31 | --sep_pos_embed \ 32 | --t_patch_size 2 \ 33 | 34 | --resume manifold://winvision/tree/feichtenhofer/logs/2022-04-07-104623-303/checkpoints/ssl_eval_checkpoint_epoch_00050.pyth \ 35 | 36 | --resume manifold://winvision/tree/feichtenhofer/logs/2022-04-05-150122-992/checkpoints/ssl_eval_checkpoint_epoch_00050.pyth \ 37 | 38 | --learnable_pos_embed \ 39 | 40 | --decoder_attn AttentionRelPos --encoder_attn AttentionRelPos --rel_pos_embed \ 41 | 42 | else 43 | 44 | if [ "$1" -lt 2 ] 45 | then 46 | 47 | echo "finetune" 48 | 49 | # AttentionSubsampleMaxpool, AttentionSubsampleStride2, AttentionSubsampleRand10, AttentionSubsampleRand25, AttentionSubsampleRand50, 50 | /data/users/"${USER}"/fbsource/fbcode/buck-out/gen/vision/fair/mae_st/run_finetune_bin.par \ 51 | --batch_size 1 --epochs 1 --repeat_aug 1 --smoothing 0.1 \ 52 | --mixup 0.0 --cutmix 0.0 --mixup_prob 0.0 \ 53 | --model vit_large_patch16 \ 54 | --t_patch_size 2 --num_frames 16 \ 55 | --rand_aug \ 56 | --sep_pos_embed \ 57 | --fp32 \ 58 | --cls_embed \ 59 | --finetune manifold://winvision/tree/haoqifan/logs/2022-05-17-131324-457/pretrain/checkpoint-00049.pth \ 60 | 61 | --finetune manifold://winvision/tree/feichtenhofer/logs/2022-04-07-104623-303/checkpoints/ssl_checkpoint_epoch_00050.pyth \ 62 | 63 | 64 | 65 | 66 | 67 | --encoder_attn AttentionOrg \ 68 | 69 | --finetune checkpoint-00000.pth \ 70 | 71 | 72 | --finetune manifold://winvision/tree/feichtenhofer/logs/2022-04-07-104623-303/checkpoints/ssl_eval_checkpoint_epoch_00050.pyth \ 73 | 74 | 75 | --finetune manifold://winvision/tree/feichtenhofer/logs/2022-04-05-150122-992/checkpoints/ssl_eval_checkpoint_epoch_00050.pyth \ 76 | 77 | --finetune manifold://winvision/tree/lyttonhao/mae_pretrain/Supin1k-ViT-Large-200ep-km_ema_wkbias.pth \ 78 | 79 | --encoder_attn AttentionRelPosWithCls \ 80 | --finetune mae_pretrain_vit_large.pth \ 81 | --no_qkv_bias \ 82 | --finetune manifold://winvision/tree/haoqifan/logs/2022-02-05-204420-480/pretrain/checkpoint-00399.pth \ 83 | 84 | # --no_qkv_bias 85 | 86 | # --encoder_attn AttentionSubsampleRand10 \ 87 | 88 | # --finetune manifold://fair_logging/tree/haoqifan/logs/2022-01-17-162701-592/pretrain/checkpoint-399.pth 89 | 90 | else 91 | 92 | echo "test" 93 | 94 | # AttentionSubsampleMaxpool, AttentionSubsampleStride2, AttentionSubsampleRand10, AttentionSubsampleRand25, AttentionSubsampleRand50, 95 | /data/users/"${USER}"/fbsource/fbcode/buck-out/gen/vision/fair/mae_st/run_test_bin.par \ 96 | --batch_size 2 97 | --model vit_large_patch16 \ 98 | --t_patch_size 2 --num_frames 16 \ 99 | --cls_embed --sep_pos_embed 100 | --finetune manifold://winvision/tree/feichtenhofer/logs/2022-04-05-150122-992/checkpoints/ssl_eval_checkpoint_epoch_00050.pyth \ 101 | 102 | --finetune manifold://fair_logging/tree/haoqifan/logs/2022-01-25-012936-625/downstream/checkpoint-99.pth 103 | 104 | # --finetune manifold://fair_logging/tree/haoqifan/logs/2022-01-17-162701-592/pretrain/checkpoint-399.pth 105 | 106 | fi 107 | fi 108 | -------------------------------------------------------------------------------- /launch_fb_tensorboard.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | #!/bin/sh 5 | 6 | 7 | # buck build @mode/opt //tensorboard 8 | if [ "$1" -lt 1 ] 9 | then 10 | ~/local/fbsource/fbcode/buck-out/gen/tensorboard/tensorboard.par --port=8092 --logdir=manifold://winvision/tree/haoqifan/logs/tensorboard/pretrain 11 | else 12 | ~/local/fbsource/fbcode/buck-out/gen/tensorboard/tensorboard.par --port=8095 --logdir=manifold://winvision/tree/haoqifan/logs/tensorboard/downstream 13 | fi 14 | 15 | # ~/local/fbsource/fbcode/buck-out/gen/tensorboard/tensorboard.par --port=8095 --logdir_spec=fair_logging:manifold://fair_logging/tree/haoqifan/logs/tensorboard/downstream,winvision:manifold://winvision/tree/haoqifan/logs/tensorboard/downstream 16 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import os 15 | import time 16 | 17 | import mae_st.util.env 18 | 19 | import mae_st.util.misc as misc 20 | 21 | import numpy as np 22 | import timm 23 | import torch 24 | import torch.backends.cudnn as cudnn 25 | from iopath.common.file_io import g_pathmgr as pathmgr 26 | from mae_st import models_mae 27 | from mae_st.engine_pretrain import train_one_epoch 28 | from mae_st.util.kinetics import Kinetics 29 | from mae_st.util.misc import NativeScalerWithGradNormCount as NativeScaler 30 | from tensorboard.compat.tensorflow_stub.io.gfile import register_filesystem 31 | from torch.utils.tensorboard import SummaryWriter 32 | 33 | 34 | def get_args_parser(): 35 | parser = argparse.ArgumentParser("MAE pre-training", add_help=False) 36 | parser.add_argument( 37 | "--batch_size", 38 | default=4, 39 | type=int, 40 | help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus", 41 | ) 42 | parser.add_argument("--epochs", default=100, type=int) 43 | parser.add_argument( 44 | "--accum_iter", 45 | default=1, 46 | type=int, 47 | help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)", 48 | ) 49 | 50 | # Model parameters 51 | parser.add_argument( 52 | "--model", 53 | default="mae_vit_large_patch16", 54 | type=str, 55 | metavar="MODEL", 56 | help="Name of model to train", 57 | ) 58 | 59 | parser.add_argument("--input_size", default=224, type=int, help="images input size") 60 | 61 | parser.add_argument( 62 | "--mask_ratio", 63 | default=0.75, 64 | type=float, 65 | help="Masking ratio (percentage of removed patches).", 66 | ) 67 | 68 | parser.add_argument( 69 | "--norm_pix_loss", 70 | action="store_true", 71 | help="Use (per-patch) normalized pixels as targets for computing loss", 72 | ) 73 | parser.set_defaults(norm_pix_loss=False) 74 | 75 | # Optimizer parameters 76 | parser.add_argument( 77 | "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)" 78 | ) 79 | 80 | parser.add_argument( 81 | "--lr", 82 | type=float, 83 | default=None, 84 | metavar="LR", 85 | help="learning rate (absolute lr)", 86 | ) 87 | parser.add_argument( 88 | "--blr", 89 | type=float, 90 | default=1e-3, 91 | metavar="LR", 92 | help="base learning rate: absolute_lr = base_lr * total_batch_size / 256", 93 | ) 94 | parser.add_argument( 95 | "--min_lr", 96 | type=float, 97 | default=0.0, 98 | metavar="LR", 99 | help="lower lr bound for cyclic schedulers that hit 0", 100 | ) 101 | 102 | parser.add_argument( 103 | "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR" 104 | ) 105 | parser.add_argument( 106 | "--path_to_data_dir", 107 | default="", 108 | help="path where to save, empty for no saving", 109 | ) 110 | parser.add_argument( 111 | "--output_dir", 112 | default="./output_dir", 113 | help="path where to save, empty for no saving", 114 | ) 115 | parser.add_argument( 116 | "--log_dir", 117 | default="", 118 | help="path where to tensorboard log", 119 | ) 120 | parser.add_argument( 121 | "--device", default="cuda", help="device to use for training / testing" 122 | ) 123 | parser.add_argument("--seed", default=0, type=int) 124 | parser.add_argument("--resume", default="", help="resume from checkpoint") 125 | 126 | parser.add_argument( 127 | "--start_epoch", default=0, type=int, metavar="N", help="start epoch" 128 | ) 129 | parser.add_argument("--num_workers", default=10, type=int) 130 | parser.add_argument( 131 | "--pin_mem", 132 | action="store_true", 133 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.", 134 | ) 135 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem") 136 | parser.set_defaults(pin_mem=True) 137 | 138 | # distributed training parameters 139 | parser.add_argument( 140 | "--world_size", default=1, type=int, help="number of distributed processes" 141 | ) 142 | parser.add_argument("--local_rank", default=-1, type=int) 143 | parser.add_argument("--dist_on_itp", action="store_true") 144 | parser.add_argument("--no_env", action="store_true") 145 | 146 | # Video related configs 147 | parser.add_argument( 148 | "--dist_url", default="env://", help="url used to set up distributed training" 149 | ) 150 | 151 | parser.add_argument("--decoder_embed_dim", default=512, type=int) 152 | parser.add_argument("--decoder_depth", default=8, type=int) 153 | parser.add_argument("--decoder_num_heads", default=16, type=int) 154 | parser.add_argument("--t_patch_size", default=2, type=int) 155 | parser.add_argument("--num_frames", default=16, type=int) 156 | parser.add_argument("--checkpoint_period", default=1, type=int) 157 | parser.add_argument("--sampling_rate", default=4, type=int) 158 | parser.add_argument("--distributed", action="store_true") 159 | parser.add_argument("--repeat_aug", default=4, type=int) 160 | parser.add_argument( 161 | "--clip_grad", 162 | type=float, 163 | default=None, 164 | ) 165 | parser.add_argument("--no_qkv_bias", action="store_true") 166 | parser.add_argument("--bias_wd", action="store_true") 167 | parser.add_argument("--num_checkpoint_del", default=20, type=int) 168 | parser.add_argument("--sep_pos_embed", action="store_true") 169 | parser.set_defaults(sep_pos_embed=True) 170 | parser.add_argument( 171 | "--trunc_init", 172 | action="store_true", 173 | ) 174 | parser.add_argument( 175 | "--fp32", 176 | action="store_true", 177 | ) 178 | parser.set_defaults(fp32=True) 179 | parser.add_argument( 180 | "--jitter_scales_relative", 181 | default=[0.5, 1.0], 182 | type=float, 183 | nargs="+", 184 | ) 185 | parser.add_argument( 186 | "--jitter_aspect_relative", 187 | default=[0.75, 1.3333], 188 | type=float, 189 | nargs="+", 190 | ) 191 | parser.add_argument( 192 | "--beta", 193 | default=None, 194 | type=float, 195 | nargs="+", 196 | ) 197 | parser.add_argument( 198 | "--pred_t_dim", 199 | type=int, 200 | default=8, 201 | ) 202 | parser.add_argument("--cls_embed", action="store_true") 203 | parser.set_defaults(cls_embed=True) 204 | return parser 205 | 206 | 207 | def main(args): 208 | misc.init_distributed_mode(args) 209 | 210 | print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) 211 | print("{}".format(args).replace(", ", ",\n")) 212 | 213 | device = torch.device(args.device) 214 | 215 | # fix the seed for reproducibility 216 | seed = args.seed + misc.get_rank() 217 | torch.manual_seed(seed) 218 | np.random.seed(seed) 219 | 220 | cudnn.benchmark = True 221 | 222 | dataset_train = Kinetics( 223 | mode="pretrain", 224 | path_to_data_dir=args.path_to_data_dir, 225 | sampling_rate=args.sampling_rate, 226 | num_frames=args.num_frames, 227 | train_jitter_scales=(256, 320), 228 | repeat_aug=args.repeat_aug, 229 | jitter_aspect_relative=args.jitter_aspect_relative, 230 | jitter_scales_relative=args.jitter_scales_relative, 231 | ) 232 | if args.distributed: 233 | num_tasks = misc.get_world_size() 234 | global_rank = misc.get_rank() 235 | sampler_train = torch.utils.data.DistributedSampler( 236 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 237 | ) 238 | print("Sampler_train = %s" % str(sampler_train)) 239 | else: 240 | num_tasks = 1 241 | global_rank = 0 242 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 243 | 244 | if global_rank == 0 and args.log_dir is not None: 245 | try: 246 | pathmgr.mkdirs(args.log_dir) 247 | except Exception as _: 248 | pass 249 | log_writer = SummaryWriter(log_dir=args.log_dir) 250 | else: 251 | log_writer = None 252 | 253 | data_loader_train = torch.utils.data.DataLoader( 254 | dataset_train, 255 | sampler=sampler_train, 256 | batch_size=args.batch_size, 257 | num_workers=args.num_workers, 258 | pin_memory=args.pin_mem, 259 | drop_last=True, 260 | ) 261 | 262 | # define the model 263 | model = models_mae.__dict__[args.model]( 264 | **vars(args), 265 | ) 266 | 267 | model.to(device) 268 | 269 | model_without_ddp = model 270 | print("Model = %s" % str(model_without_ddp)) 271 | 272 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 273 | 274 | if args.lr is None: # only base_lr is specified 275 | args.lr = args.blr * eff_batch_size / 256 276 | 277 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 278 | print("actual lr: %.2e" % args.lr) 279 | 280 | print("accumulate grad iterations: %d" % args.accum_iter) 281 | print("effective batch size: %d" % eff_batch_size) 282 | 283 | if args.distributed: 284 | model = torch.nn.parallel.DistributedDataParallel( 285 | model, 286 | device_ids=[torch.cuda.current_device()], 287 | # find_unused_parameters=True, 288 | ) 289 | model_without_ddp = model.module 290 | 291 | # following timm: set wd as 0 for bias and norm layers 292 | param_groups = misc.add_weight_decay( 293 | model_without_ddp, 294 | args.weight_decay, 295 | bias_wd=args.bias_wd, 296 | ) 297 | if args.beta is None: 298 | beta = (0.9, 0.95) 299 | else: 300 | beta = args.beta 301 | optimizer = torch.optim._multi_tensor.AdamW( 302 | param_groups, 303 | lr=args.lr, 304 | betas=beta, 305 | ) 306 | loss_scaler = NativeScaler(fp32=args.fp32) 307 | 308 | misc.load_model( 309 | args=args, 310 | model_without_ddp=model_without_ddp, 311 | optimizer=optimizer, 312 | loss_scaler=loss_scaler, 313 | ) 314 | 315 | checkpoint_path = "" 316 | print(f"Start training for {args.epochs} epochs") 317 | start_time = time.time() 318 | for epoch in range(args.start_epoch, args.epochs): 319 | if args.distributed: 320 | data_loader_train.sampler.set_epoch(epoch) 321 | train_stats = train_one_epoch( 322 | model, 323 | data_loader_train, 324 | optimizer, 325 | device, 326 | epoch, 327 | loss_scaler, 328 | log_writer=log_writer, 329 | args=args, 330 | fp32=args.fp32, 331 | ) 332 | if args.output_dir and ( 333 | epoch % args.checkpoint_period == 0 or epoch + 1 == args.epochs 334 | ): 335 | checkpoint_path = misc.save_model( 336 | args=args, 337 | model=model, 338 | model_without_ddp=model_without_ddp, 339 | optimizer=optimizer, 340 | loss_scaler=loss_scaler, 341 | epoch=epoch, 342 | ) 343 | 344 | log_stats = { 345 | **{f"train_{k}": v for k, v in train_stats.items()}, 346 | "epoch": epoch, 347 | } 348 | 349 | if args.output_dir and misc.is_main_process(): 350 | if log_writer is not None: 351 | log_writer.flush() 352 | with pathmgr.open( 353 | f"{args.output_dir}/log.txt", 354 | "a", 355 | ) as f: 356 | f.write(json.dumps(log_stats) + "\n") 357 | 358 | total_time = time.time() - start_time 359 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 360 | print("Training time {}".format(total_time_str)) 361 | print(torch.cuda.memory_allocated()) 362 | return [checkpoint_path] 363 | 364 | 365 | def launch_one_thread( 366 | local_rank, 367 | shard_rank, 368 | num_gpus_per_node, 369 | num_shards, 370 | init_method, 371 | output_path, 372 | opts, 373 | stats_queue, 374 | ): 375 | print(opts) 376 | args = get_args_parser() 377 | args = args.parse_args(opts) 378 | args.rank = shard_rank * num_gpus_per_node + local_rank 379 | args.world_size = num_shards * num_gpus_per_node 380 | args.gpu = local_rank 381 | args.dist_url = init_method 382 | args.output_dir = output_path 383 | output = main(args) 384 | stats_queue.put(output) 385 | -------------------------------------------------------------------------------- /main_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import os 14 | 15 | import mae_st.models_vit as models_vit 16 | import mae_st.util.misc as misc 17 | 18 | import numpy as np 19 | import timm 20 | import torch 21 | import torch.backends.cudnn as cudnn 22 | from iopath.common.file_io import g_pathmgr as pathmgr 23 | from mae_st.engine_test import test 24 | from mae_st.util.kinetics import Kinetics 25 | from mae_st.util.logging import master_print as print 26 | from mae_st.util.meters import TestMeter 27 | from mae_st.util.pos_embed import interpolate_pos_embed 28 | 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser( 32 | "MAE fine-tuning for image classification", add_help=False 33 | ) 34 | parser.add_argument( 35 | "--batch_size", 36 | default=64, 37 | type=int, 38 | help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus", 39 | ) 40 | # Model parameters 41 | parser.add_argument( 42 | "--model", 43 | default="vit_large_patch16", 44 | type=str, 45 | metavar="MODEL", 46 | help="Name of model to train", 47 | ) 48 | 49 | parser.add_argument("--input_size", default=224, type=int, help="images input size") 50 | 51 | parser.add_argument( 52 | "--dropout", 53 | type=float, 54 | default=0.5, 55 | ) 56 | 57 | parser.add_argument( 58 | "--drop_path_rate", 59 | type=float, 60 | default=0.1, 61 | metavar="PCT", 62 | help="Drop path rate (default: 0.1)", 63 | ) 64 | 65 | # * Finetuning params 66 | parser.add_argument("--finetune", default="", help="finetune from checkpoint") 67 | parser.add_argument("--global_pool", action="store_true") 68 | parser.set_defaults(global_pool=True) 69 | parser.add_argument( 70 | "--cls_token", 71 | action="store_false", 72 | dest="global_pool", 73 | help="Use class token instead of global pool for classification", 74 | ) 75 | parser.add_argument( 76 | "--nb_classes", 77 | default=400, 78 | type=int, 79 | help="number of the classification types", 80 | ) 81 | parser.add_argument( 82 | "--path_to_data_dir", 83 | default="", 84 | help="path where to save, empty for no saving", 85 | ) 86 | parser.add_argument( 87 | "--output_dir", 88 | default="./output_dir", 89 | help="path where to save, empty for no saving", 90 | ) 91 | parser.add_argument( 92 | "--log_dir", 93 | default="", 94 | help="path where to tensorboard log", 95 | ) 96 | parser.add_argument( 97 | "--device", default="cuda", help="device to use for training / testing" 98 | ) 99 | parser.add_argument("--seed", default=0, type=int) 100 | parser.add_argument("--resume", default="", help="resume from checkpoint") 101 | 102 | parser.add_argument("--eval", action="store_true", help="Perform evaluation only") 103 | parser.add_argument( 104 | "--dist_eval", 105 | action="store_true", 106 | default=False, 107 | help="Enabling distributed evaluation (recommended during training for faster monitor", 108 | ) 109 | parser.add_argument("--num_workers", default=10, type=int) 110 | parser.add_argument( 111 | "--pin_mem", 112 | action="store_true", 113 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.", 114 | ) 115 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem") 116 | parser.set_defaults(pin_mem=True) 117 | 118 | # distributed training parameters 119 | parser.add_argument( 120 | "--world_size", default=1, type=int, help="number of distributed processes" 121 | ) 122 | parser.add_argument("--local_rank", default=-1, type=int) 123 | parser.add_argument("--dist_on_itp", action="store_true") 124 | parser.add_argument( 125 | "--dist_url", default="env://", help="url used to set up distributed training" 126 | ) 127 | 128 | # Video related configs 129 | parser.add_argument("--no_env", action="store_true") 130 | parser.add_argument("--rand_aug", default=False, action="store_true") 131 | parser.add_argument("--t_patch_size", default=4, type=int) 132 | parser.add_argument("--num_frames", default=32, type=int) 133 | parser.add_argument("--checkpoint_period", default=10, type=int) 134 | parser.add_argument("--sampling_rate", default=2, type=int) 135 | parser.add_argument("--distributed", action="store_true") 136 | parser.add_argument("--repeat_aug", default=1, type=int) 137 | parser.add_argument("--encoder_attn", default="AttentionSubsampleMaxpool", type=str) 138 | 139 | # Dataset parameters 140 | parser.add_argument( 141 | "--decoder_device", 142 | default="cpu", 143 | type=str, 144 | ) 145 | # Dataset parameters 146 | parser.add_argument( 147 | "--decoder_backend", 148 | default="torchvision", 149 | type=str, 150 | ) 151 | parser.add_argument("--no_qkv_bias", action="store_true") 152 | parser.add_argument("--sep_pos_embed", action="store_true") 153 | parser.add_argument( 154 | "--fp32", 155 | action="store_true", 156 | ) 157 | parser.add_argument("--cls_embed", action="store_true") 158 | return parser 159 | 160 | 161 | def main(args): 162 | misc.init_distributed_mode(args) 163 | 164 | print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) 165 | print("{}".format(args).replace(", ", ",\n")) 166 | 167 | device = torch.device(args.device) 168 | 169 | # fix the seed for reproducibility 170 | seed = args.seed + misc.get_rank() 171 | torch.manual_seed(seed) 172 | np.random.seed(seed) 173 | 174 | cudnn.benchmark = True 175 | 176 | dataset_test = Kinetics( 177 | mode="test", 178 | path_to_data_dir=args.path_to_data_dir, 179 | sampling_rate=args.sampling_rate, 180 | num_frames=args.num_frames, 181 | train_jitter_scales=(256, 320), 182 | test_crop_size=224, 183 | repeat_aug=args.repeat_aug, 184 | rand_aug=False, 185 | ) 186 | test_meter = TestMeter( 187 | dataset_test.num_videos // (3 * 10), 188 | 3 * 10, 189 | args.nb_classes, 190 | len(dataset_test), 191 | False, 192 | "sum", 193 | ) 194 | 195 | if args.distributed: 196 | num_tasks = misc.get_world_size() 197 | global_rank = misc.get_rank() 198 | if args.dist_eval: 199 | if len(dataset_test) % num_tasks != 0: 200 | print( 201 | "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. " 202 | "This will slightly alter validation results as extra duplicate entries are added to achieve " 203 | "equal num of samples per-process." 204 | ) 205 | sampler_test = torch.utils.data.DistributedSampler( 206 | dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True 207 | ) # shuffle=True to reduce monitor bias 208 | else: 209 | sampler_test = torch.utils.data.SequentialSampler(dataset_test) 210 | else: 211 | num_tasks = 1 212 | global_rank = 0 213 | sampler_test = torch.utils.data.RandomSampler(dataset_test) 214 | data_loader_test = torch.utils.data.DataLoader( 215 | dataset_test, 216 | sampler=sampler_test, 217 | batch_size=args.batch_size, 218 | num_workers=args.num_workers, 219 | pin_memory=args.pin_mem, 220 | drop_last=False, 221 | ) 222 | 223 | model = models_vit.__dict__[args.model]( 224 | num_classes=args.nb_classes, 225 | **vars(args), 226 | ) 227 | 228 | with pathmgr.open(args.finetune, "rb") as f: 229 | checkpoint = torch.load(f, map_location="cpu") 230 | 231 | print("Load pre-trained checkpoint from: %s" % args.finetune) 232 | if "model" in checkpoint.keys(): 233 | checkpoint_model = checkpoint["model"] 234 | else: 235 | checkpoint_model = checkpoint["model_state"] 236 | # interpolate position embedding 237 | interpolate_pos_embed(model, checkpoint_model) 238 | 239 | checkpoint_model = misc.convert_checkpoint(checkpoint_model) 240 | 241 | # load pre-trained model 242 | msg = model.load_state_dict(checkpoint_model, strict=False) 243 | print(msg) 244 | model.to(device) 245 | 246 | model_without_ddp = model 247 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 248 | 249 | print("Model = %s" % str(model_without_ddp)) 250 | print("number of params (M): %.2f" % (n_parameters / 1.0e6)) 251 | 252 | if args.distributed: 253 | model = torch.nn.parallel.DistributedDataParallel( 254 | model, device_ids=[torch.cuda.current_device()] 255 | ) 256 | model_without_ddp = model.module 257 | 258 | log_stats = test(data_loader_test, model, device, test_meter, fp32=args.fp32) 259 | return [log_stats] 260 | 261 | 262 | def launch_one_thread( 263 | local_rank, 264 | shard_rank, 265 | num_gpus_per_node, 266 | num_shards, 267 | init_method, 268 | output_path, 269 | opts, 270 | stats_queue, 271 | ): 272 | print(opts) 273 | args = get_args_parser() 274 | args = args.parse_args(opts) 275 | args.rank = shard_rank * num_gpus_per_node + local_rank 276 | args.world_size = num_shards * num_gpus_per_node 277 | args.gpu = local_rank 278 | args.dist_url = init_method 279 | args.output_dir = output_path 280 | output = main(args) 281 | stats_queue.put(output) 282 | -------------------------------------------------------------------------------- /models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # MAE: https://github.com/facebookresearch/mae 11 | # -------------------------------------------------------- 12 | 13 | from functools import partial 14 | 15 | import torch 16 | import torch.nn as nn 17 | from mae_st.util import video_vit 18 | from mae_st.util.logging import master_print as print 19 | 20 | 21 | class MaskedAutoencoderViT(nn.Module): 22 | """Masked Autoencoder with VisionTransformer backbone""" 23 | 24 | def __init__( 25 | self, 26 | img_size=224, 27 | patch_size=16, 28 | in_chans=3, 29 | embed_dim=1024, 30 | depth=24, 31 | num_heads=16, 32 | decoder_embed_dim=512, 33 | decoder_depth=8, 34 | decoder_num_heads=16, 35 | mlp_ratio=4.0, 36 | norm_layer=nn.LayerNorm, 37 | norm_pix_loss=False, 38 | num_frames=16, 39 | t_patch_size=4, 40 | patch_embed=video_vit.PatchEmbed, 41 | no_qkv_bias=False, 42 | sep_pos_embed=False, 43 | trunc_init=False, 44 | cls_embed=False, 45 | pred_t_dim=8, 46 | **kwargs, 47 | ): 48 | super().__init__() 49 | self.trunc_init = trunc_init 50 | self.sep_pos_embed = sep_pos_embed 51 | self.cls_embed = cls_embed 52 | self.pred_t_dim = pred_t_dim 53 | self.t_pred_patch_size = t_patch_size * pred_t_dim // num_frames 54 | 55 | self.patch_embed = patch_embed( 56 | img_size, 57 | patch_size, 58 | in_chans, 59 | embed_dim, 60 | num_frames, 61 | t_patch_size, 62 | ) 63 | num_patches = self.patch_embed.num_patches 64 | input_size = self.patch_embed.input_size 65 | self.input_size = input_size 66 | 67 | if self.cls_embed: 68 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 69 | self.decoder_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 70 | 71 | if sep_pos_embed: 72 | self.pos_embed_spatial = nn.Parameter( 73 | torch.zeros(1, input_size[1] * input_size[2], embed_dim) 74 | ) 75 | self.pos_embed_temporal = nn.Parameter( 76 | torch.zeros(1, input_size[0], embed_dim) 77 | ) 78 | if self.cls_embed: 79 | self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim)) 80 | else: 81 | if self.cls_embed: 82 | _num_patches = num_patches + 1 83 | else: 84 | _num_patches = num_patches 85 | 86 | self.pos_embed = nn.Parameter( 87 | torch.zeros(1, _num_patches, embed_dim), 88 | ) 89 | 90 | self.blocks = nn.ModuleList( 91 | [ 92 | video_vit.Block( 93 | embed_dim, 94 | num_heads, 95 | mlp_ratio, 96 | qkv_bias=not no_qkv_bias, 97 | qk_scale=None, 98 | norm_layer=norm_layer, 99 | ) 100 | for i in range(depth) 101 | ] 102 | ) 103 | self.norm = norm_layer(embed_dim) 104 | 105 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 106 | 107 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 108 | 109 | if sep_pos_embed: 110 | self.decoder_pos_embed_spatial = nn.Parameter( 111 | torch.zeros(1, input_size[1] * input_size[2], decoder_embed_dim) 112 | ) 113 | self.decoder_pos_embed_temporal = nn.Parameter( 114 | torch.zeros(1, input_size[0], decoder_embed_dim) 115 | ) 116 | if self.cls_embed: 117 | self.decoder_pos_embed_class = nn.Parameter( 118 | torch.zeros(1, 1, decoder_embed_dim) 119 | ) 120 | else: 121 | if self.cls_embed: 122 | _num_patches = num_patches + 1 123 | else: 124 | _num_patches = num_patches 125 | 126 | self.decoder_pos_embed = nn.Parameter( 127 | torch.zeros(1, _num_patches, decoder_embed_dim), 128 | ) 129 | 130 | self.decoder_blocks = nn.ModuleList( 131 | [ 132 | video_vit.Block( 133 | decoder_embed_dim, 134 | decoder_num_heads, 135 | mlp_ratio, 136 | qkv_bias=not no_qkv_bias, 137 | qk_scale=None, 138 | norm_layer=norm_layer, 139 | ) 140 | for i in range(decoder_depth) 141 | ] 142 | ) 143 | 144 | self.decoder_norm = norm_layer(decoder_embed_dim) 145 | self.decoder_pred = nn.Linear( 146 | decoder_embed_dim, 147 | self.t_pred_patch_size * patch_size**2 * in_chans, 148 | bias=True, 149 | ) 150 | 151 | self.norm_pix_loss = norm_pix_loss 152 | 153 | self.initialize_weights() 154 | 155 | print("model initialized") 156 | 157 | def initialize_weights(self): 158 | if self.cls_embed: 159 | torch.nn.init.trunc_normal_(self.cls_token, std=0.02) 160 | if self.sep_pos_embed: 161 | torch.nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02) 162 | torch.nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02) 163 | 164 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_spatial, std=0.02) 165 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_temporal, std=0.02) 166 | 167 | if self.cls_embed: 168 | torch.nn.init.trunc_normal_(self.pos_embed_class, std=0.02) 169 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_class, std=0.02) 170 | else: 171 | torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) 172 | torch.nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02) 173 | w = self.patch_embed.proj.weight.data 174 | if self.trunc_init: 175 | torch.nn.init.trunc_normal_(w) 176 | torch.nn.init.trunc_normal_(self.mask_token, std=0.02) 177 | else: 178 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 179 | torch.nn.init.normal_(self.mask_token, std=0.02) 180 | 181 | # initialize nn.Linear and nn.LayerNorm 182 | self.apply(self._init_weights) 183 | 184 | def _init_weights(self, m): 185 | if isinstance(m, nn.Linear): 186 | # we use xavier_uniform following official JAX ViT: 187 | if self.trunc_init: 188 | nn.init.trunc_normal_(m.weight, std=0.02) 189 | else: 190 | torch.nn.init.xavier_uniform_(m.weight) 191 | if isinstance(m, nn.Linear) and m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | elif isinstance(m, nn.LayerNorm): 194 | nn.init.constant_(m.bias, 0) 195 | nn.init.constant_(m.weight, 1.0) 196 | 197 | def patchify(self, imgs): 198 | """ 199 | imgs: (N, 3, H, W) 200 | x: (N, L, patch_size**2 *3) 201 | """ 202 | N, _, T, H, W = imgs.shape 203 | p = self.patch_embed.patch_size[0] 204 | u = self.t_pred_patch_size 205 | assert H == W and H % p == 0 and T % u == 0 206 | h = w = H // p 207 | t = T // u 208 | 209 | x = imgs.reshape(shape=(N, 3, t, u, h, p, w, p)) 210 | x = torch.einsum("nctuhpwq->nthwupqc", x) 211 | x = x.reshape(shape=(N, t * h * w, u * p**2 * 3)) 212 | self.patch_info = (N, T, H, W, p, u, t, h, w) 213 | return x 214 | 215 | def unpatchify(self, x): 216 | """ 217 | x: (N, L, patch_size**2 *3) 218 | imgs: (N, 3, H, W) 219 | """ 220 | N, T, H, W, p, u, t, h, w = self.patch_info 221 | 222 | x = x.reshape(shape=(N, t, h, w, u, p, p, 3)) 223 | 224 | x = torch.einsum("nthwupqc->nctuhpwq", x) 225 | imgs = x.reshape(shape=(N, 3, T, H, W)) 226 | return imgs 227 | 228 | def random_masking(self, x, mask_ratio): 229 | """ 230 | Perform per-sample random masking by per-sample shuffling. 231 | Per-sample shuffling is done by argsort random noise. 232 | x: [N, L, D], sequence 233 | """ 234 | N, L, D = x.shape # batch, length, dim 235 | len_keep = int(L * (1 - mask_ratio)) 236 | 237 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 238 | 239 | # sort noise for each sample 240 | ids_shuffle = torch.argsort( 241 | noise, dim=1 242 | ) # ascend: small is keep, large is remove 243 | ids_restore = torch.argsort(ids_shuffle, dim=1) 244 | 245 | # keep the first subset 246 | ids_keep = ids_shuffle[:, :len_keep] 247 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 248 | 249 | # generate the binary mask: 0 is keep, 1 is remove 250 | mask = torch.ones([N, L], device=x.device) 251 | mask[:, :len_keep] = 0 252 | # unshuffle to get the binary mask 253 | mask = torch.gather(mask, dim=1, index=ids_restore) 254 | 255 | return x_masked, mask, ids_restore, ids_keep 256 | 257 | def forward_encoder(self, x, mask_ratio): 258 | # embed patches 259 | x = self.patch_embed(x) 260 | N, T, L, C = x.shape 261 | 262 | x = x.reshape(N, T * L, C) 263 | 264 | # masking: length -> length * mask_ratio 265 | x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio) 266 | x = x.view(N, -1, C) 267 | # append cls token 268 | if self.cls_embed: 269 | cls_token = self.cls_token 270 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 271 | x = torch.cat((cls_tokens, x), dim=1) 272 | 273 | # add pos embed w/o cls token 274 | if self.sep_pos_embed: 275 | pos_embed = self.pos_embed_spatial.repeat( 276 | 1, self.input_size[0], 1 277 | ) + torch.repeat_interleave( 278 | self.pos_embed_temporal, 279 | self.input_size[1] * self.input_size[2], 280 | dim=1, 281 | ) 282 | pos_embed = pos_embed.expand(x.shape[0], -1, -1) 283 | pos_embed = torch.gather( 284 | pos_embed, 285 | dim=1, 286 | index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2]), 287 | ) 288 | if self.cls_embed: 289 | pos_embed = torch.cat( 290 | [ 291 | self.pos_embed_class.expand(pos_embed.shape[0], -1, -1), 292 | pos_embed, 293 | ], 294 | 1, 295 | ) 296 | else: 297 | if self.cls_embed: 298 | cls_ind = 1 299 | else: 300 | cls_ind = 0 301 | pos_embed = self.pos_embed[:, cls_ind:, :].expand(x.shape[0], -1, -1) 302 | pos_embed = torch.gather( 303 | pos_embed, 304 | dim=1, 305 | index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2]), 306 | ) 307 | if self.cls_embed: 308 | pos_embed = torch.cat( 309 | [ 310 | self.pos_embed[:, :1, :].expand(x.shape[0], -1, -1), 311 | pos_embed, 312 | ], 313 | 1, 314 | ) 315 | x = x.view([N, -1, C]) + pos_embed 316 | 317 | # apply Transformer blocks 318 | for blk in self.blocks: 319 | x = blk(x) 320 | x = self.norm(x) 321 | 322 | if self.cls_embed: 323 | # remove cls token 324 | x = x[:, 1:, :] 325 | else: 326 | x = x[:, :, :] 327 | 328 | return x, mask, ids_restore 329 | 330 | def forward_decoder(self, x, ids_restore): 331 | N = x.shape[0] 332 | T = self.patch_embed.t_grid_size 333 | H = W = self.patch_embed.grid_size 334 | 335 | # embed tokens 336 | x = self.decoder_embed(x) 337 | C = x.shape[-1] 338 | 339 | # append mask tokens to sequence 340 | mask_tokens = self.mask_token.repeat(N, T * H * W + 0 - x.shape[1], 1) 341 | x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token 342 | x_ = x_.view([N, T * H * W, C]) 343 | x_ = torch.gather( 344 | x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2]) 345 | ) # unshuffle 346 | x = x_.view([N, T * H * W, C]) 347 | # append cls token 348 | if self.cls_embed: 349 | decoder_cls_token = self.decoder_cls_token 350 | decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1) 351 | x = torch.cat((decoder_cls_tokens, x), dim=1) 352 | 353 | if self.sep_pos_embed: 354 | decoder_pos_embed = self.decoder_pos_embed_spatial.repeat( 355 | 1, self.input_size[0], 1 356 | ) + torch.repeat_interleave( 357 | self.decoder_pos_embed_temporal, 358 | self.input_size[1] * self.input_size[2], 359 | dim=1, 360 | ) 361 | if self.cls_embed: 362 | decoder_pos_embed = torch.cat( 363 | [ 364 | self.decoder_pos_embed_class.expand( 365 | decoder_pos_embed.shape[0], -1, -1 366 | ), 367 | decoder_pos_embed, 368 | ], 369 | 1, 370 | ) 371 | else: 372 | decoder_pos_embed = self.decoder_pos_embed[:, :, :] 373 | 374 | # add pos embed 375 | x = x + decoder_pos_embed 376 | 377 | attn = self.decoder_blocks[0].attn 378 | requires_t_shape = hasattr(attn, "requires_t_shape") and attn.requires_t_shape 379 | if requires_t_shape: 380 | x = x.view([N, T, H * W, C]) 381 | 382 | # apply Transformer blocks 383 | for blk in self.decoder_blocks: 384 | x = blk(x) 385 | x = self.decoder_norm(x) 386 | 387 | # predictor projection 388 | x = self.decoder_pred(x) 389 | 390 | if requires_t_shape: 391 | x = x.view([N, T * H * W, -1]) 392 | 393 | if self.cls_embed: 394 | # remove cls token 395 | x = x[:, 1:, :] 396 | else: 397 | x = x[:, :, :] 398 | 399 | return x 400 | 401 | def forward_loss(self, imgs, pred, mask): 402 | """ 403 | imgs: [N, 3, T, H, W] 404 | pred: [N, t*h*w, u*p*p*3] 405 | mask: [N*t, h*w], 0 is keep, 1 is remove, 406 | """ 407 | _imgs = torch.index_select( 408 | imgs, 409 | 2, 410 | torch.linspace( 411 | 0, 412 | imgs.shape[2] - 1, 413 | self.pred_t_dim, 414 | ) 415 | .long() 416 | .to(imgs.device), 417 | ) 418 | target = self.patchify(_imgs) 419 | if self.norm_pix_loss: 420 | mean = target.mean(dim=-1, keepdim=True) 421 | var = target.var(dim=-1, keepdim=True) 422 | target = (target - mean) / (var + 1.0e-6) ** 0.5 423 | 424 | loss = (pred - target) ** 2 425 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 426 | mask = mask.view(loss.shape) 427 | 428 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 429 | return loss 430 | 431 | def forward(self, imgs, mask_ratio=0.75): 432 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 433 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 434 | loss = self.forward_loss(imgs, pred, mask) 435 | return loss, pred, mask 436 | 437 | 438 | def mae_vit_base_patch16(**kwargs): 439 | model = MaskedAutoencoderViT( 440 | patch_size=16, 441 | embed_dim=768, 442 | depth=12, 443 | num_heads=12, 444 | mlp_ratio=4, 445 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 446 | **kwargs, 447 | ) 448 | return model 449 | 450 | 451 | def mae_vit_large_patch16(**kwargs): 452 | model = MaskedAutoencoderViT( 453 | patch_size=16, 454 | embed_dim=1024, 455 | depth=24, 456 | num_heads=16, 457 | mlp_ratio=4, 458 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 459 | **kwargs, 460 | ) 461 | return model 462 | 463 | 464 | def mae_vit_huge_patch14(**kwargs): 465 | model = MaskedAutoencoderViT( 466 | patch_size=14, 467 | embed_dim=1280, 468 | depth=32, 469 | num_heads=16, 470 | mlp_ratio=4, 471 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 472 | **kwargs, 473 | ) 474 | return model 475 | -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # MAE: https://github.com/facebookresearch/mae 11 | # -------------------------------------------------------- 12 | 13 | from functools import partial 14 | 15 | import torch 16 | import torch.nn as nn 17 | from mae_st.util.logging import master_print as print 18 | 19 | from mae_st.util.video_vit import Attention, Block, PatchEmbed 20 | 21 | 22 | class VisionTransformer(nn.Module): 23 | """Vision Transformer with support for global average pooling""" 24 | 25 | def __init__( 26 | self, 27 | num_frames, 28 | t_patch_size, 29 | img_size=224, 30 | patch_size=16, 31 | in_chans=3, 32 | num_classes=400, 33 | embed_dim=768, 34 | depth=12, 35 | num_heads=12, 36 | mlp_ratio=4.0, 37 | no_qkv_bias=False, 38 | qk_scale=None, 39 | drop_rate=0.0, 40 | attn_drop_rate=0.0, 41 | drop_path_rate=0.0, 42 | norm_layer=nn.LayerNorm, 43 | dropout=0.5, 44 | sep_pos_embed=False, 45 | cls_embed=False, 46 | **kwargs, 47 | ): 48 | super().__init__() 49 | print(locals()) 50 | 51 | self.sep_pos_embed = sep_pos_embed 52 | # -------------------------------------------------------------------------- 53 | # MAE encoder specifics 54 | self.patch_embed = PatchEmbed( 55 | img_size, patch_size, in_chans, embed_dim, num_frames, t_patch_size 56 | ) 57 | num_patches = self.patch_embed.num_patches 58 | input_size = self.patch_embed.input_size 59 | self.input_size = input_size 60 | self.cls_embed = cls_embed 61 | 62 | if self.cls_embed: 63 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 64 | 65 | if sep_pos_embed: 66 | self.pos_embed_spatial = nn.Parameter( 67 | torch.zeros(1, input_size[1] * input_size[2], embed_dim) 68 | ) 69 | self.pos_embed_temporal = nn.Parameter( 70 | torch.zeros(1, input_size[0], embed_dim) 71 | ) 72 | if self.cls_embed: 73 | self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim)) 74 | else: 75 | if self.cls_embed: 76 | _num_patches = num_patches + 1 77 | else: 78 | _num_patches = num_patches 79 | 80 | self.pos_embed = nn.Parameter( 81 | torch.zeros(1, _num_patches, embed_dim), requires_grad=True 82 | ) # fixed or not? 83 | 84 | dpr = [ 85 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 86 | ] # stochastic depth decay rule 87 | 88 | self.blocks = nn.ModuleList( 89 | [ 90 | Block( 91 | embed_dim, 92 | num_heads, 93 | mlp_ratio, 94 | qkv_bias=not no_qkv_bias, 95 | qk_scale=None, 96 | norm_layer=norm_layer, 97 | drop_path=dpr[i], 98 | attn_func=partial( 99 | Attention, 100 | input_size=self.patch_embed.input_size, 101 | ), 102 | ) 103 | for i in range(depth) 104 | ] 105 | ) 106 | self.norm = norm_layer(embed_dim) 107 | # -------------------------------------------------------------------------- 108 | 109 | self.dropout = nn.Dropout(dropout) 110 | self.head = nn.Linear(embed_dim, num_classes) 111 | 112 | torch.nn.init.normal_(self.head.weight, std=0.02) 113 | 114 | @torch.jit.ignore 115 | def no_weight_decay(self): 116 | return { 117 | "cls_token", 118 | "pos_embed", 119 | "pos_embed_spatial", 120 | "pos_embed_temporal", 121 | "pos_embed_class", 122 | } 123 | 124 | def forward(self, x): 125 | # embed patches 126 | x = self.patch_embed(x) 127 | N, T, L, C = x.shape # T: temporal; L: spatial 128 | 129 | x = x.view([N, T * L, C]) 130 | 131 | # append cls token 132 | if self.cls_embed: 133 | cls_token = self.cls_token 134 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 135 | x = torch.cat((cls_tokens, x), dim=1) 136 | 137 | if self.sep_pos_embed: 138 | pos_embed = self.pos_embed_spatial.repeat( 139 | 1, self.input_size[0], 1 140 | ) + torch.repeat_interleave( 141 | self.pos_embed_temporal, 142 | self.input_size[1] * self.input_size[2], 143 | dim=1, 144 | ) 145 | if self.cls_embed: 146 | pos_embed = torch.cat( 147 | [ 148 | self.pos_embed_class.expand(pos_embed.shape[0], -1, -1), 149 | pos_embed, 150 | ], 151 | 1, 152 | ) 153 | else: 154 | pos_embed = self.pos_embed[:, :, :] 155 | x = x + pos_embed 156 | 157 | # reshape to [N, T, L, C] or [N, T*L, C] 158 | requires_t_shape = ( 159 | len(self.blocks) > 0 # support empty decoder 160 | and hasattr(self.blocks[0].attn, "requires_t_shape") 161 | and self.blocks[0].attn.requires_t_shape 162 | ) 163 | if requires_t_shape: 164 | x = x.view([N, T, L, C]) 165 | 166 | # apply Transformer blocks 167 | for blk in self.blocks: 168 | x = blk(x) 169 | 170 | if requires_t_shape: 171 | x = x.view([N, T * L, C]) 172 | 173 | # classifier 174 | x = x[:, 1:, :].mean(dim=1) # global pool 175 | x = self.norm(x) 176 | # x = self.fc_norm(x) 177 | x = self.dropout(x) 178 | x = self.head(x) 179 | 180 | return x 181 | 182 | 183 | def vit_base_patch16(**kwargs): 184 | model = VisionTransformer( 185 | patch_size=16, 186 | embed_dim=768, 187 | depth=12, 188 | num_heads=12, 189 | mlp_ratio=4, 190 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 191 | **kwargs, 192 | ) 193 | return model 194 | 195 | 196 | def vit_large_patch16(**kwargs): 197 | model = VisionTransformer( 198 | patch_size=16, 199 | embed_dim=1024, 200 | depth=24, 201 | num_heads=16, 202 | mlp_ratio=4, 203 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 204 | **kwargs, 205 | ) 206 | return model 207 | 208 | 209 | def vit_huge_patch14(**kwargs): 210 | model = VisionTransformer( 211 | patch_size=16, 212 | embed_dim=1280, 213 | depth=32, 214 | num_heads=16, 215 | mlp_ratio=4, 216 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 217 | **kwargs, 218 | ) 219 | return model 220 | -------------------------------------------------------------------------------- /run_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | from pathlib import Path 6 | 7 | from main_finetune import get_args_parser, main 8 | 9 | 10 | def invoke_main() -> None: 11 | global args 12 | args = get_args_parser() 13 | args = args.parse_args() 14 | if args.output_dir: 15 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 16 | main(args) 17 | 18 | 19 | if __name__ == "__main__": 20 | invoke_main() # pragma: no cover 21 | -------------------------------------------------------------------------------- /run_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | from pathlib import Path 6 | 7 | from main_pretrain import get_args_parser, main 8 | 9 | 10 | def invoke_main() -> None: 11 | args = get_args_parser() 12 | args = args.parse_args() 13 | if args.output_dir: 14 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 15 | main(args) 16 | 17 | 18 | if __name__ == "__main__": 19 | invoke_main() # pragma: no cover 20 | -------------------------------------------------------------------------------- /run_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | from pathlib import Path 6 | 7 | from main_test import get_args_parser, main 8 | 9 | 10 | def invoke_main() -> None: 11 | global args 12 | args = get_args_parser() 13 | args = args.parse_args() 14 | if args.output_dir: 15 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 16 | main(args) 17 | 18 | 19 | if __name__ == "__main__": 20 | invoke_main() # pragma: no cover 21 | -------------------------------------------------------------------------------- /util/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import math 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | import torchvision.io as io 11 | 12 | 13 | def temporal_sampling(frames, start_idx, end_idx, num_samples): 14 | """ 15 | Given the start and end frame index, sample num_samples frames between 16 | the start and end with equal interval. 17 | Args: 18 | frames (tensor): a tensor of video frames, dimension is 19 | `num video frames` x `channel` x `height` x `width`. 20 | start_idx (int): the index of the start frame. 21 | end_idx (int): the index of the end frame. 22 | num_samples (int): number of frames to sample. 23 | Returns: 24 | frames (tersor): a tensor of temporal sampled video frames, dimension is 25 | `num clip frames` x `channel` x `height` x `width`. 26 | """ 27 | index = torch.linspace(start_idx, end_idx, num_samples) 28 | index = torch.clamp(index, 0, frames.shape[0] - 1).long() 29 | new_frames = torch.index_select(frames, 0, index) 30 | return new_frames 31 | 32 | 33 | def get_start_end_idx(video_size, clip_size, clip_idx, num_clips, use_offset=False): 34 | """ 35 | Sample a clip of size clip_size from a video of size video_size and 36 | return the indices of the first and last frame of the clip. If clip_idx is 37 | -1, the clip is randomly sampled, otherwise uniformly split the video to 38 | num_clips clips, and select the start and end index of clip_idx-th video 39 | clip. 40 | Args: 41 | video_size (int): number of overall frames. 42 | clip_size (int): size of the clip to sample from the frames. 43 | clip_idx (int): if clip_idx is -1, perform random jitter sampling. If 44 | clip_idx is larger than -1, uniformly split the video to num_clips 45 | clips, and select the start and end index of the clip_idx-th video 46 | clip. 47 | num_clips (int): overall number of clips to uniformly sample from the 48 | given video for testing. 49 | Returns: 50 | start_idx (int): the start frame index. 51 | end_idx (int): the end frame index. 52 | """ 53 | delta = max(video_size - clip_size, 0) 54 | if clip_idx == -1: 55 | # Random temporal sampling. 56 | start_idx = random.uniform(0, delta) 57 | else: 58 | if use_offset: 59 | if num_clips == 1: 60 | # Take the center clip if num_clips is 1. 61 | start_idx = math.floor(delta / 2) 62 | else: 63 | # Uniformly sample the clip with the given index. 64 | start_idx = clip_idx * math.floor(delta / (num_clips - 1)) 65 | else: 66 | # Uniformly sample the clip with the given index. 67 | start_idx = delta * clip_idx / num_clips 68 | end_idx = start_idx + clip_size - 1 69 | return start_idx, end_idx 70 | 71 | 72 | def decode( 73 | container, 74 | sampling_rate, 75 | num_frames, 76 | clip_idx=-1, 77 | num_clips=10, 78 | video_meta=None, 79 | target_fps=30, 80 | max_spatial_scale=0, 81 | use_offset=False, 82 | rigid_decode_all_video=True, 83 | modalities=("visual",), 84 | ): 85 | """ 86 | Decode the video and perform temporal sampling. 87 | Args: 88 | container (container): pyav container. 89 | sampling_rate (int): frame sampling rate (interval between two sampled 90 | frames). 91 | num_frames (int): number of frames to sample. 92 | clip_idx (int): if clip_idx is -1, perform random temporal 93 | sampling. If clip_idx is larger than -1, uniformly split the 94 | video to num_clips clips, and select the 95 | clip_idx-th video clip. 96 | num_clips (int): overall number of clips to uniformly 97 | sample from the given video. 98 | video_meta (dict): a dict contains VideoMetaData. Details can be find 99 | at `pytorch/vision/torchvision/io/_video_opt.py`. 100 | target_fps (int): the input video may have different fps, convert it to 101 | the target video fps before frame sampling. 102 | max_spatial_scale (int): keep the aspect ratio and resize the frame so 103 | that shorter edge size is max_spatial_scale. Only used in 104 | `torchvision` backend. 105 | Returns: 106 | frames (tensor): decoded frames from the video. 107 | """ 108 | try: 109 | assert clip_idx >= -1, "Not valied clip_idx {}".format(clip_idx) 110 | # Convert the bytes to a tensor. 111 | video_tensor = torch.from_numpy(np.frombuffer(container, dtype=np.uint8)) 112 | 113 | decode_all_video = True 114 | video_start_pts, video_end_pts = 0, -1 115 | # The video_meta is empty, fetch the meta data from the raw video. 116 | if len(video_meta) == 0: 117 | # Tracking the meta info for selective decoding in the future. 118 | meta = io._probe_video_from_memory(video_tensor) 119 | # Using the information from video_meta to perform selective decoding. 120 | video_meta["video_timebase"] = meta.video_timebase 121 | video_meta["video_numerator"] = meta.video_timebase.numerator 122 | video_meta["video_denominator"] = meta.video_timebase.denominator 123 | video_meta["has_video"] = meta.has_video 124 | video_meta["video_duration"] = meta.video_duration 125 | video_meta["video_fps"] = meta.video_fps 126 | video_meta["audio_timebas"] = meta.audio_timebase 127 | video_meta["audio_numerator"] = meta.audio_timebase.numerator 128 | video_meta["audio_denominator"] = meta.audio_timebase.denominator 129 | video_meta["has_audio"] = meta.has_audio 130 | video_meta["audio_duration"] = meta.audio_duration 131 | video_meta["audio_sample_rate"] = meta.audio_sample_rate 132 | 133 | fps = video_meta["video_fps"] 134 | if not rigid_decode_all_video: 135 | if ( 136 | video_meta["has_video"] 137 | and video_meta["video_denominator"] > 0 138 | and video_meta["video_duration"] > 0 139 | ): 140 | # try selective decoding. 141 | decode_all_video = False 142 | clip_size = sampling_rate * num_frames / target_fps * fps 143 | start_idx, end_idx = get_start_end_idx( 144 | fps * video_meta["video_duration"], 145 | clip_size, 146 | clip_idx, 147 | num_clips, 148 | use_offset=use_offset, 149 | ) 150 | # Convert frame index to pts. 151 | pts_per_frame = video_meta["video_denominator"] / fps 152 | video_start_pts = int(start_idx * pts_per_frame) 153 | video_end_pts = int(end_idx * pts_per_frame) 154 | 155 | # Decode the raw video with the tv decoder. 156 | v_frames, _ = io._read_video_from_memory( 157 | video_tensor, 158 | seek_frame_margin=1.0, 159 | read_video_stream="visual" in modalities, 160 | video_width=0, 161 | video_height=0, 162 | video_min_dimension=max_spatial_scale, 163 | video_pts_range=(video_start_pts, video_end_pts), 164 | video_timebase_numerator=video_meta["video_numerator"], 165 | video_timebase_denominator=video_meta["video_denominator"], 166 | ) 167 | 168 | if v_frames.shape == torch.Size([0]): 169 | # failed selective decoding 170 | decode_all_video = True 171 | video_start_pts, video_end_pts = 0, -1 172 | v_frames, _ = io._read_video_from_memory( 173 | video_tensor, 174 | seek_frame_margin=1.0, 175 | read_video_stream="visual" in modalities, 176 | video_width=0, 177 | video_height=0, 178 | video_min_dimension=max_spatial_scale, 179 | video_pts_range=(video_start_pts, video_end_pts), 180 | video_timebase_numerator=video_meta["video_numerator"], 181 | video_timebase_denominator=video_meta["video_denominator"], 182 | ) 183 | except Exception as e: 184 | print("Failed to decode by torchvision with exception: {}".format(e)) 185 | return None 186 | 187 | # Return None if the frames was not decoded successfully. 188 | if v_frames is None or v_frames.size(0) == 0: 189 | return None, fps, decode_all_video 190 | return v_frames, fps, decode_all_video 191 | -------------------------------------------------------------------------------- /util/decoder/mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """ 6 | This implementation is based on 7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/mixup.py, 8 | published under an Apache License 2.0. 9 | 10 | COMMENT FROM ORIGINAL: 11 | Mixup and Cutmix 12 | Papers: 13 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 14 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) # NOQA 15 | Code Reference: 16 | CutMix: https://github.com/clovaai/CutMix-PyTorch 17 | Hacked together by / Copyright 2020 Ross Wightman 18 | """ 19 | 20 | import numpy as np 21 | import torch 22 | 23 | 24 | def convert_to_one_hot(targets, num_classes, on_value=1.0, off_value=0.0): 25 | """ 26 | This function converts target class indices to one-hot vectors, given the 27 | number of classes. 28 | Args: 29 | targets (loader): Class labels. 30 | num_classes (int): Total number of classes. 31 | on_value (float): Target Value for ground truth class. 32 | off_value (float): Target Value for other classes.This value is used for 33 | label smoothing. 34 | """ 35 | 36 | targets = targets.long().view(-1, 1) 37 | return torch.full( 38 | (targets.size()[0], num_classes), off_value, device=targets.device 39 | ).scatter_(1, targets, on_value) 40 | 41 | 42 | def mixup_target(target, num_classes, lam=1.0, smoothing=0.0): 43 | """ 44 | This function converts target class indices to one-hot vectors, given the 45 | number of classes. 46 | Args: 47 | targets (loader): Class labels. 48 | num_classes (int): Total number of classes. 49 | lam (float): lamba value for mixup/cutmix. 50 | smoothing (float): Label smoothing value. 51 | """ 52 | off_value = smoothing / num_classes 53 | on_value = 1.0 - smoothing + off_value 54 | target1 = convert_to_one_hot( 55 | target, 56 | num_classes, 57 | on_value=on_value, 58 | off_value=off_value, 59 | ) 60 | target2 = convert_to_one_hot( 61 | target.flip(0), 62 | num_classes, 63 | on_value=on_value, 64 | off_value=off_value, 65 | ) 66 | return target1 * lam + target2 * (1.0 - lam) 67 | 68 | 69 | def rand_bbox(img_shape, lam, margin=0.0, count=None): 70 | """ 71 | Generates a random square bbox based on lambda value. 72 | 73 | Args: 74 | img_shape (tuple): Image shape as tuple 75 | lam (float): Cutmix lambda value 76 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 77 | count (int): Number of bbox to generate 78 | """ 79 | ratio = np.sqrt(1 - lam) 80 | img_h, img_w = img_shape[-2:] 81 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 82 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 83 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 84 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 85 | yl = np.clip(cy - cut_h // 2, 0, img_h) 86 | yh = np.clip(cy + cut_h // 2, 0, img_h) 87 | xl = np.clip(cx - cut_w // 2, 0, img_w) 88 | xh = np.clip(cx + cut_w // 2, 0, img_w) 89 | return yl, yh, xl, xh 90 | 91 | 92 | def get_cutmix_bbox(img_shape, lam, correct_lam=True, count=None): 93 | """ 94 | Generates the box coordinates for cutmix. 95 | 96 | Args: 97 | img_shape (tuple): Image shape as tuple 98 | lam (float): Cutmix lambda value 99 | correct_lam (bool): Apply lambda correction when cutmix bbox clipped by 100 | image borders. 101 | count (int): Number of bbox to generate 102 | """ 103 | 104 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 105 | if correct_lam: 106 | bbox_area = (yu - yl) * (xu - xl) 107 | lam = 1.0 - bbox_area / float(img_shape[-2] * img_shape[-1]) 108 | return (yl, yu, xl, xu), lam 109 | 110 | 111 | class MixUp: 112 | """ 113 | Apply mixup and/or cutmix for videos at batch level. 114 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 115 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable 116 | Features (https://arxiv.org/abs/1905.04899) 117 | """ 118 | 119 | def __init__( 120 | self, 121 | mixup_alpha=1.0, 122 | cutmix_alpha=0.0, 123 | mix_prob=1.0, 124 | switch_prob=0.5, 125 | correct_lam=True, 126 | label_smoothing=0.1, 127 | num_classes=1000, 128 | ): 129 | """ 130 | Args: 131 | mixup_alpha (float): Mixup alpha value. 132 | cutmix_alpha (float): Cutmix alpha value. 133 | mix_prob (float): Probability of applying mixup or cutmix. 134 | switch_prob (float): Probability of switching to cutmix instead of 135 | mixup when both are active. 136 | correct_lam (bool): Apply lambda correction when cutmix bbox 137 | clipped by image borders. 138 | label_smoothing (float): Apply label smoothing to the mixed target 139 | tensor. If label_smoothing is not used, set it to 0. 140 | num_classes (int): Number of classes for target. 141 | """ 142 | self.mixup_alpha = mixup_alpha 143 | self.cutmix_alpha = cutmix_alpha 144 | self.mix_prob = mix_prob 145 | self.switch_prob = switch_prob 146 | self.label_smoothing = label_smoothing 147 | self.num_classes = num_classes 148 | self.correct_lam = correct_lam 149 | 150 | def _get_mixup_params(self): 151 | lam = 1.0 152 | use_cutmix = False 153 | if np.random.rand() < self.mix_prob: 154 | if self.mixup_alpha > 0.0 and self.cutmix_alpha > 0.0: 155 | use_cutmix = np.random.rand() < self.switch_prob 156 | lam_mix = ( 157 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 158 | if use_cutmix 159 | else np.random.beta(self.mixup_alpha, self.mixup_alpha) 160 | ) 161 | elif self.mixup_alpha > 0.0: 162 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 163 | elif self.cutmix_alpha > 0.0: 164 | use_cutmix = True 165 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 166 | lam = float(lam_mix) 167 | return lam, use_cutmix 168 | 169 | def _mix_batch(self, x): 170 | lam, use_cutmix = self._get_mixup_params() 171 | if lam == 1.0: 172 | return 1.0 173 | if use_cutmix: 174 | (yl, yh, xl, xh), lam = get_cutmix_bbox( 175 | x.shape, 176 | lam, 177 | correct_lam=self.correct_lam, 178 | ) 179 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 180 | else: 181 | x_flipped = x.flip(0).mul_(1.0 - lam) 182 | x.mul_(lam).add_(x_flipped) 183 | return lam 184 | 185 | def __call__(self, x, target): 186 | assert len(x) > 1, "Batch size should be greater than 1 for mixup." 187 | lam = self._mix_batch(x) 188 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing) 189 | return x, target 190 | -------------------------------------------------------------------------------- /util/decoder/rand_augment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """ 6 | This implementation is based on 7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 8 | pulished under an Apache License 2.0. 9 | 10 | COMMENT FROM ORIGINAL: 11 | AutoAugment, RandAugment, and AugMix for PyTorch 12 | This code implements the searched ImageNet policies with various tweaks and 13 | improvements and does not include any of the search code. AA and RA 14 | Implementation adapted from: 15 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 16 | AugMix adapted from: 17 | https://github.com/google-research/augmix 18 | Papers: 19 | AutoAugment: Learning Augmentation Policies from Data 20 | https://arxiv.org/abs/1805.09501 21 | Learning Data Augmentation Strategies for Object Detection 22 | https://arxiv.org/abs/1906.11172 23 | RandAugment: Practical automated data augmentation... 24 | https://arxiv.org/abs/1909.13719 25 | AugMix: A Simple Data Processing Method to Improve Robustness and 26 | Uncertainty https://arxiv.org/abs/1912.02781 27 | 28 | Hacked together by / Copyright 2020 Ross Wightman 29 | """ 30 | 31 | import math 32 | import random 33 | import re 34 | 35 | import numpy as np 36 | import PIL 37 | from PIL import Image, ImageEnhance, ImageOps 38 | 39 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 40 | 41 | _FILL = (128, 128, 128) 42 | 43 | # This signifies the max integer that the controller RNN could predict for the 44 | # augmentation scheme. 45 | _MAX_LEVEL = 10.0 46 | 47 | _HPARAMS_DEFAULT = { 48 | "translate_const": 250, 49 | "img_mean": _FILL, 50 | } 51 | 52 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 53 | 54 | 55 | def _interpolation(kwargs): 56 | interpolation = kwargs.pop("resample", Image.BILINEAR) 57 | if isinstance(interpolation, (list, tuple)): 58 | return random.choice(interpolation) 59 | else: 60 | return interpolation 61 | 62 | 63 | def _check_args_tf(kwargs): 64 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 65 | kwargs.pop("fillcolor") 66 | kwargs["resample"] = _interpolation(kwargs) 67 | 68 | 69 | def shear_x(img, factor, **kwargs): 70 | _check_args_tf(kwargs) 71 | return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) 72 | 73 | 74 | def shear_y(img, factor, **kwargs): 75 | _check_args_tf(kwargs) 76 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) 77 | 78 | 79 | def translate_x_rel(img, pct, **kwargs): 80 | pixels = pct * img.size[0] 81 | _check_args_tf(kwargs) 82 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 83 | 84 | 85 | def translate_y_rel(img, pct, **kwargs): 86 | pixels = pct * img.size[1] 87 | _check_args_tf(kwargs) 88 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 89 | 90 | 91 | def translate_x_abs(img, pixels, **kwargs): 92 | _check_args_tf(kwargs) 93 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 94 | 95 | 96 | def translate_y_abs(img, pixels, **kwargs): 97 | _check_args_tf(kwargs) 98 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 99 | 100 | 101 | def rotate(img, degrees, **kwargs): 102 | _check_args_tf(kwargs) 103 | if _PIL_VER >= (5, 2): 104 | return img.rotate(degrees, **kwargs) 105 | elif _PIL_VER >= (5, 0): 106 | w, h = img.size 107 | post_trans = (0, 0) 108 | rotn_center = (w / 2.0, h / 2.0) 109 | angle = -math.radians(degrees) 110 | matrix = [ 111 | round(math.cos(angle), 15), 112 | round(math.sin(angle), 15), 113 | 0.0, 114 | round(-math.sin(angle), 15), 115 | round(math.cos(angle), 15), 116 | 0.0, 117 | ] 118 | 119 | def transform(x, y, matrix): 120 | (a, b, c, d, e, f) = matrix 121 | return a * x + b * y + c, d * x + e * y + f 122 | 123 | matrix[2], matrix[5] = transform( 124 | -rotn_center[0] - post_trans[0], 125 | -rotn_center[1] - post_trans[1], 126 | matrix, 127 | ) 128 | matrix[2] += rotn_center[0] 129 | matrix[5] += rotn_center[1] 130 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 131 | else: 132 | return img.rotate(degrees, resample=kwargs["resample"]) 133 | 134 | 135 | def auto_contrast(img, **__): 136 | return ImageOps.autocontrast(img) 137 | 138 | 139 | def invert(img, **__): 140 | return ImageOps.invert(img) 141 | 142 | 143 | def equalize(img, **__): 144 | return ImageOps.equalize(img) 145 | 146 | 147 | def solarize(img, thresh, **__): 148 | return ImageOps.solarize(img, thresh) 149 | 150 | 151 | def solarize_add(img, add, thresh=128, **__): 152 | lut = [] 153 | for i in range(256): 154 | if i < thresh: 155 | lut.append(min(255, i + add)) 156 | else: 157 | lut.append(i) 158 | if img.mode in ("L", "RGB"): 159 | if img.mode == "RGB" and len(lut) == 256: 160 | lut = lut + lut + lut 161 | return img.point(lut) 162 | else: 163 | return img 164 | 165 | 166 | def posterize(img, bits_to_keep, **__): 167 | if bits_to_keep >= 8: 168 | return img 169 | return ImageOps.posterize(img, bits_to_keep) 170 | 171 | 172 | def contrast(img, factor, **__): 173 | return ImageEnhance.Contrast(img).enhance(factor) 174 | 175 | 176 | def color(img, factor, **__): 177 | return ImageEnhance.Color(img).enhance(factor) 178 | 179 | 180 | def brightness(img, factor, **__): 181 | return ImageEnhance.Brightness(img).enhance(factor) 182 | 183 | 184 | def sharpness(img, factor, **__): 185 | return ImageEnhance.Sharpness(img).enhance(factor) 186 | 187 | 188 | def _randomly_negate(v): 189 | """With 50% prob, negate the value""" 190 | return -v if random.random() > 0.5 else v 191 | 192 | 193 | def _rotate_level_to_arg(level, _hparams): 194 | # range [-30, 30] 195 | level = (level / _MAX_LEVEL) * 30.0 196 | level = _randomly_negate(level) 197 | return (level,) 198 | 199 | 200 | def _enhance_level_to_arg(level, _hparams): 201 | # range [0.1, 1.9] 202 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 203 | 204 | 205 | def _enhance_increasing_level_to_arg(level, _hparams): 206 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 207 | # range [0.1, 1.9] 208 | level = (level / _MAX_LEVEL) * 0.9 209 | level = 1.0 + _randomly_negate(level) 210 | return (level,) 211 | 212 | 213 | def _shear_level_to_arg(level, _hparams): 214 | # range [-0.3, 0.3] 215 | level = (level / _MAX_LEVEL) * 0.3 216 | level = _randomly_negate(level) 217 | return (level,) 218 | 219 | 220 | def _translate_abs_level_to_arg(level, hparams): 221 | translate_const = hparams["translate_const"] 222 | level = (level / _MAX_LEVEL) * float(translate_const) 223 | level = _randomly_negate(level) 224 | return (level,) 225 | 226 | 227 | def _translate_rel_level_to_arg(level, hparams): 228 | # default range [-0.45, 0.45] 229 | translate_pct = hparams.get("translate_pct", 0.45) 230 | level = (level / _MAX_LEVEL) * translate_pct 231 | level = _randomly_negate(level) 232 | return (level,) 233 | 234 | 235 | def _posterize_level_to_arg(level, _hparams): 236 | # As per Tensorflow TPU EfficientNet impl 237 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 238 | # intensity/severity of augmentation decreases with level 239 | return (int((level / _MAX_LEVEL) * 4),) 240 | 241 | 242 | def _posterize_increasing_level_to_arg(level, hparams): 243 | # As per Tensorflow models research and UDA impl 244 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 245 | # intensity/severity of augmentation increases with level 246 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 247 | 248 | 249 | def _posterize_original_level_to_arg(level, _hparams): 250 | # As per original AutoAugment paper description 251 | # range [4, 8], 'keep 4 up to 8 MSB of image' 252 | # intensity/severity of augmentation decreases with level 253 | return (int((level / _MAX_LEVEL) * 4) + 4,) 254 | 255 | 256 | def _solarize_level_to_arg(level, _hparams): 257 | # range [0, 256] 258 | # intensity/severity of augmentation decreases with level 259 | return (int((level / _MAX_LEVEL) * 256),) 260 | 261 | 262 | def _solarize_increasing_level_to_arg(level, _hparams): 263 | # range [0, 256] 264 | # intensity/severity of augmentation increases with level 265 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 266 | 267 | 268 | def _solarize_add_level_to_arg(level, _hparams): 269 | # range [0, 110] 270 | return (int((level / _MAX_LEVEL) * 110),) 271 | 272 | 273 | LEVEL_TO_ARG = { 274 | "AutoContrast": None, 275 | "Equalize": None, 276 | "Invert": None, 277 | "Rotate": _rotate_level_to_arg, 278 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 279 | "Posterize": _posterize_level_to_arg, 280 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 281 | "PosterizeOriginal": _posterize_original_level_to_arg, 282 | "Solarize": _solarize_level_to_arg, 283 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 284 | "SolarizeAdd": _solarize_add_level_to_arg, 285 | "Color": _enhance_level_to_arg, 286 | "ColorIncreasing": _enhance_increasing_level_to_arg, 287 | "Contrast": _enhance_level_to_arg, 288 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 289 | "Brightness": _enhance_level_to_arg, 290 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 291 | "Sharpness": _enhance_level_to_arg, 292 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 293 | "ShearX": _shear_level_to_arg, 294 | "ShearY": _shear_level_to_arg, 295 | "TranslateX": _translate_abs_level_to_arg, 296 | "TranslateY": _translate_abs_level_to_arg, 297 | "TranslateXRel": _translate_rel_level_to_arg, 298 | "TranslateYRel": _translate_rel_level_to_arg, 299 | } 300 | 301 | 302 | NAME_TO_OP = { 303 | "AutoContrast": auto_contrast, 304 | "Equalize": equalize, 305 | "Invert": invert, 306 | "Rotate": rotate, 307 | "Posterize": posterize, 308 | "PosterizeIncreasing": posterize, 309 | "PosterizeOriginal": posterize, 310 | "Solarize": solarize, 311 | "SolarizeIncreasing": solarize, 312 | "SolarizeAdd": solarize_add, 313 | "Color": color, 314 | "ColorIncreasing": color, 315 | "Contrast": contrast, 316 | "ContrastIncreasing": contrast, 317 | "Brightness": brightness, 318 | "BrightnessIncreasing": brightness, 319 | "Sharpness": sharpness, 320 | "SharpnessIncreasing": sharpness, 321 | "ShearX": shear_x, 322 | "ShearY": shear_y, 323 | "TranslateX": translate_x_abs, 324 | "TranslateY": translate_y_abs, 325 | "TranslateXRel": translate_x_rel, 326 | "TranslateYRel": translate_y_rel, 327 | } 328 | 329 | 330 | class AugmentOp: 331 | """ 332 | Apply for video. 333 | """ 334 | 335 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 336 | hparams = hparams or _HPARAMS_DEFAULT 337 | self.aug_fn = NAME_TO_OP[name] 338 | self.level_fn = LEVEL_TO_ARG[name] 339 | self.prob = prob 340 | self.magnitude = magnitude 341 | self.hparams = hparams.copy() 342 | self.kwargs = { 343 | "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL, 344 | "resample": ( 345 | hparams["interpolation"] 346 | if "interpolation" in hparams 347 | else _RANDOM_INTERPOLATION 348 | ), 349 | } 350 | 351 | # If magnitude_std is > 0, we introduce some randomness 352 | # in the usually fixed policy and sample magnitude from a normal distribution 353 | # with mean `magnitude` and std-dev of `magnitude_std`. 354 | # NOTE This is my own hack, being tested, not in papers or reference impls. 355 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 356 | 357 | def __call__(self, img_list): 358 | if self.prob < 1.0 and random.random() > self.prob: 359 | return img_list 360 | magnitude = self.magnitude 361 | if self.magnitude_std and self.magnitude_std > 0: 362 | magnitude = random.gauss(magnitude, self.magnitude_std) 363 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 364 | level_args = ( 365 | self.level_fn(magnitude, self.hparams) if self.level_fn is not None else () 366 | ) 367 | 368 | if isinstance(img_list, list): 369 | return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list] 370 | else: 371 | return self.aug_fn(img_list, *level_args, **self.kwargs) 372 | 373 | 374 | _RAND_TRANSFORMS = [ 375 | "AutoContrast", 376 | "Equalize", 377 | "Invert", 378 | "Rotate", 379 | "Posterize", 380 | "Solarize", 381 | "SolarizeAdd", 382 | "Color", 383 | "Contrast", 384 | "Brightness", 385 | "Sharpness", 386 | "ShearX", 387 | "ShearY", 388 | "TranslateXRel", 389 | "TranslateYRel", 390 | ] 391 | 392 | 393 | _RAND_INCREASING_TRANSFORMS = [ 394 | "AutoContrast", 395 | "Equalize", 396 | "Invert", 397 | "Rotate", 398 | "PosterizeIncreasing", 399 | "SolarizeIncreasing", 400 | "SolarizeAdd", 401 | "ColorIncreasing", 402 | "ContrastIncreasing", 403 | "BrightnessIncreasing", 404 | "SharpnessIncreasing", 405 | "ShearX", 406 | "ShearY", 407 | "TranslateXRel", 408 | "TranslateYRel", 409 | ] 410 | 411 | 412 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 413 | # They may not result in increased performance, but could likely be tuned to so. 414 | _RAND_CHOICE_WEIGHTS_0 = { 415 | "Rotate": 0.3, 416 | "ShearX": 0.2, 417 | "ShearY": 0.2, 418 | "TranslateXRel": 0.1, 419 | "TranslateYRel": 0.1, 420 | "Color": 0.025, 421 | "Sharpness": 0.025, 422 | "AutoContrast": 0.025, 423 | "Solarize": 0.005, 424 | "SolarizeAdd": 0.005, 425 | "Contrast": 0.005, 426 | "Brightness": 0.005, 427 | "Equalize": 0.005, 428 | "Posterize": 0, 429 | "Invert": 0, 430 | } 431 | 432 | 433 | def _select_rand_weights(weight_idx=0, transforms=None): 434 | transforms = transforms or _RAND_TRANSFORMS 435 | assert weight_idx == 0 # only one set of weights currently 436 | rand_weights = _RAND_CHOICE_WEIGHTS_0 437 | probs = [rand_weights[k] for k in transforms] 438 | probs /= np.sum(probs) 439 | return probs 440 | 441 | 442 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 443 | hparams = hparams or _HPARAMS_DEFAULT 444 | transforms = transforms or _RAND_TRANSFORMS 445 | return [ 446 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 447 | for name in transforms 448 | ] 449 | 450 | 451 | class RandAugment: 452 | def __init__(self, ops, num_layers=2, choice_weights=None): 453 | self.ops = ops 454 | self.num_layers = num_layers 455 | self.choice_weights = choice_weights 456 | 457 | def __call__(self, img): 458 | # no replacement when using weighted choice 459 | ops = np.random.choice( 460 | self.ops, 461 | self.num_layers, 462 | replace=self.choice_weights is None, 463 | p=self.choice_weights, 464 | ) 465 | for op in ops: 466 | img = op(img) 467 | return img 468 | 469 | 470 | def rand_augment_transform(config_str, hparams): 471 | """ 472 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 473 | 474 | Create a RandAugment transform 475 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by 476 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 477 | sections, not order sepecific determine 478 | 'm' - integer magnitude of rand augment 479 | 'n' - integer num layers (number of transform ops selected per image) 480 | 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 481 | 'mstd' - float std deviation of magnitude noise applied 482 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 483 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 484 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 485 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 486 | :return: A PyTorch compatible Transform 487 | """ 488 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 489 | num_layers = 2 # default to 2 ops per image 490 | weight_idx = None # default to no probability weights for op choice 491 | transforms = _RAND_TRANSFORMS 492 | config = config_str.split("-") 493 | assert config[0] == "rand" 494 | config = config[1:] 495 | for c in config: 496 | cs = re.split(r"(\d.*)", c) 497 | if len(cs) < 2: 498 | continue 499 | key, val = cs[:2] 500 | if key == "mstd": 501 | # noise param injected via hparams for now 502 | hparams.setdefault("magnitude_std", float(val)) 503 | elif key == "inc": 504 | if bool(val): 505 | transforms = _RAND_INCREASING_TRANSFORMS 506 | elif key == "m": 507 | magnitude = int(val) 508 | elif key == "n": 509 | num_layers = int(val) 510 | elif key == "w": 511 | weight_idx = int(val) 512 | else: 513 | assert NotImplementedError 514 | ra_ops = rand_augment_ops( 515 | magnitude=magnitude, hparams=hparams, transforms=transforms 516 | ) 517 | choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) 518 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 519 | -------------------------------------------------------------------------------- /util/decoder/random_erasing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """ 6 | This implementation is based on 7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 8 | pulished under an Apache License 2.0. 9 | 10 | COMMENT FROM ORIGINAL: 11 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 12 | Copyright Zhun Zhong & Liang Zheng 13 | Hacked together by / Copyright 2020 Ross Wightman 14 | """ 15 | 16 | import math 17 | import random 18 | 19 | import torch 20 | 21 | 22 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"): 23 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 24 | # paths, flip the order so normal is run on CPU if this becomes a problem 25 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 26 | if per_pixel: 27 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 28 | elif rand_color: 29 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 30 | else: 31 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 32 | 33 | 34 | class RandomErasing: 35 | """Randomly selects a rectangle region in an image and erases its pixels. 36 | 'Random Erasing Data Augmentation' by Zhong et al. 37 | See https://arxiv.org/pdf/1708.04896.pdf 38 | This variant of RandomErasing is intended to be applied to either a batch 39 | or single image tensor after it has been normalized by dataset mean and std. 40 | Args: 41 | probability: Probability that the Random Erasing operation will be performed. 42 | min_area: Minimum percentage of erased area wrt input image area. 43 | max_area: Maximum percentage of erased area wrt input image area. 44 | min_aspect: Minimum aspect ratio of erased area. 45 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 46 | 'const' - erase block is constant color of 0 for all channels 47 | 'rand' - erase block is same per-channel random (normal) color 48 | 'pixel' - erase block is per-pixel random (normal) color 49 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 50 | per-image count is randomly chosen between 1 and this value. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | probability=0.5, 56 | min_area=0.02, 57 | max_area=1 / 3, 58 | min_aspect=0.3, 59 | max_aspect=None, 60 | mode="const", 61 | min_count=1, 62 | max_count=None, 63 | num_splits=0, 64 | device="cuda", 65 | cube=True, 66 | ): 67 | self.probability = probability 68 | self.min_area = min_area 69 | self.max_area = max_area 70 | max_aspect = max_aspect or 1 / min_aspect 71 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 72 | self.min_count = min_count 73 | self.max_count = max_count or min_count 74 | self.num_splits = num_splits 75 | mode = mode.lower() 76 | self.rand_color = False 77 | self.per_pixel = False 78 | self.cube = cube 79 | if mode == "rand": 80 | self.rand_color = True # per block random normal 81 | elif mode == "pixel": 82 | self.per_pixel = True # per pixel random normal 83 | else: 84 | assert not mode or mode == "const" 85 | self.device = device 86 | 87 | def _erase(self, img, chan, img_h, img_w, dtype): 88 | if random.random() > self.probability: 89 | return 90 | area = img_h * img_w 91 | count = ( 92 | self.min_count 93 | if self.min_count == self.max_count 94 | else random.randint(self.min_count, self.max_count) 95 | ) 96 | for _ in range(count): 97 | for _ in range(10): 98 | target_area = ( 99 | random.uniform(self.min_area, self.max_area) * area / count 100 | ) 101 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 102 | h = int(round(math.sqrt(target_area * aspect_ratio))) 103 | w = int(round(math.sqrt(target_area / aspect_ratio))) 104 | if w < img_w and h < img_h: 105 | top = random.randint(0, img_h - h) 106 | left = random.randint(0, img_w - w) 107 | img[:, top : top + h, left : left + w] = _get_pixels( 108 | self.per_pixel, 109 | self.rand_color, 110 | (chan, h, w), 111 | dtype=dtype, 112 | device=self.device, 113 | ) 114 | break 115 | 116 | def _erase_cube( 117 | self, 118 | img, 119 | batch_start, 120 | batch_size, 121 | chan, 122 | img_h, 123 | img_w, 124 | dtype, 125 | ): 126 | if random.random() > self.probability: 127 | return 128 | area = img_h * img_w 129 | count = ( 130 | self.min_count 131 | if self.min_count == self.max_count 132 | else random.randint(self.min_count, self.max_count) 133 | ) 134 | for _ in range(count): 135 | for _ in range(100): 136 | target_area = ( 137 | random.uniform(self.min_area, self.max_area) * area / count 138 | ) 139 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 140 | h = int(round(math.sqrt(target_area * aspect_ratio))) 141 | w = int(round(math.sqrt(target_area / aspect_ratio))) 142 | if w < img_w and h < img_h: 143 | top = random.randint(0, img_h - h) 144 | left = random.randint(0, img_w - w) 145 | for i in range(batch_start, batch_size): 146 | img_instance = img[i] 147 | img_instance[:, top : top + h, left : left + w] = _get_pixels( 148 | self.per_pixel, 149 | self.rand_color, 150 | (chan, h, w), 151 | dtype=dtype, 152 | device=self.device, 153 | ) 154 | break 155 | 156 | def __call__(self, input): 157 | if len(input.size()) == 3: 158 | self._erase(input, *input.size(), input.dtype) 159 | else: 160 | batch_size, chan, img_h, img_w = input.size() 161 | # skip first slice of batch if num_splits is set (for clean portion of samples) 162 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 163 | if self.cube: 164 | self._erase_cube( 165 | input, 166 | batch_start, 167 | batch_size, 168 | chan, 169 | img_h, 170 | img_w, 171 | input.dtype, 172 | ) 173 | else: 174 | for i in range(batch_start, batch_size): 175 | self._erase(input[i], chan, img_h, img_w, input.dtype) 176 | return input 177 | -------------------------------------------------------------------------------- /util/decoder/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import logging 6 | import os 7 | import random 8 | import time 9 | from collections import defaultdict 10 | 11 | import cv2 12 | import numpy as np 13 | import torch 14 | from iopath.common.file_io import g_pathmgr as pathmgr 15 | from torch.utils.data.distributed import DistributedSampler 16 | 17 | from . import transform as transform 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def retry_load_images(image_paths, retry=10, backend="pytorch"): 23 | """ 24 | This function is to load images with support of retrying for failed load. 25 | 26 | Args: 27 | image_paths (list): paths of images needed to be loaded. 28 | retry (int, optional): maximum time of loading retrying. Defaults to 10. 29 | backend (str): `pytorch` or `cv2`. 30 | 31 | Returns: 32 | imgs (list): list of loaded images. 33 | """ 34 | for i in range(retry): 35 | imgs = [] 36 | for image_path in image_paths: 37 | with pathmgr.open(image_path, "rb") as f: 38 | img_str = np.frombuffer(f.read(), np.uint8) 39 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR) 40 | imgs.append(img) 41 | 42 | if all(img is not None for img in imgs): 43 | if backend == "pytorch": 44 | imgs = torch.as_tensor(np.stack(imgs)) 45 | return imgs 46 | else: 47 | logger.warn("Reading failed. Will retry.") 48 | time.sleep(1.0) 49 | if i == retry - 1: 50 | raise Exception("Failed to load images {}".format(image_paths)) 51 | 52 | 53 | def get_sequence(center_idx, half_len, sample_rate, num_frames): 54 | """ 55 | Sample frames among the corresponding clip. 56 | 57 | Args: 58 | center_idx (int): center frame idx for current clip 59 | half_len (int): half of the clip length 60 | sample_rate (int): sampling rate for sampling frames inside of the clip 61 | num_frames (int): number of expected sampled frames 62 | 63 | Returns: 64 | seq (list): list of indexes of sampled frames in this clip. 65 | """ 66 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate)) 67 | 68 | for seq_idx in range(len(seq)): 69 | if seq[seq_idx] < 0: 70 | seq[seq_idx] = 0 71 | elif seq[seq_idx] >= num_frames: 72 | seq[seq_idx] = num_frames - 1 73 | return seq 74 | 75 | 76 | def spatial_sampling( 77 | frames, 78 | spatial_idx=-1, 79 | min_scale=256, 80 | max_scale=320, 81 | crop_size=224, 82 | random_horizontal_flip=True, 83 | inverse_uniform_sampling=False, 84 | aspect_ratio=None, 85 | scale=None, 86 | motion_shift=False, 87 | ): 88 | """ 89 | Perform spatial sampling on the given video frames. If spatial_idx is 90 | -1, perform random scale, random crop, and random flip on the given 91 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 92 | with the given spatial_idx. 93 | Args: 94 | frames (tensor): frames of images sampled from the video. The 95 | dimension is `num frames` x `height` x `width` x `channel`. 96 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 97 | or 2, perform left, center, right crop if width is larger than 98 | height, and perform top, center, buttom crop if height is larger 99 | than width. 100 | min_scale (int): the minimal size of scaling. 101 | max_scale (int): the maximal size of scaling. 102 | crop_size (int): the size of height and width used to crop the 103 | frames. 104 | inverse_uniform_sampling (bool): if True, sample uniformly in 105 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 106 | scale. If False, take a uniform sample from [min_scale, 107 | max_scale]. 108 | aspect_ratio (list): Aspect ratio range for resizing. 109 | scale (list): Scale range for resizing. 110 | motion_shift (bool): Whether to apply motion shift for resizing. 111 | Returns: 112 | frames (tensor): spatially sampled frames. 113 | """ 114 | assert spatial_idx in [-1, 0, 1, 2] 115 | if spatial_idx == -1: 116 | if aspect_ratio is None and scale is None: 117 | frames = transform.random_short_side_scale_jitter( 118 | images=frames, 119 | min_size=min_scale, 120 | max_size=max_scale, 121 | inverse_uniform_sampling=inverse_uniform_sampling, 122 | ) 123 | frames = transform.random_crop(frames, crop_size) 124 | else: 125 | transform_func = ( 126 | transform.random_resized_crop_with_shift 127 | if motion_shift 128 | else transform.random_resized_crop 129 | ) 130 | frames = transform_func( 131 | images=frames, 132 | target_height=crop_size, 133 | target_width=crop_size, 134 | scale=scale, 135 | ratio=aspect_ratio, 136 | ) 137 | if random_horizontal_flip: 138 | frames = transform.horizontal_flip(0.5, frames) 139 | else: 140 | # The testing is deterministic and no jitter should be performed. 141 | # min_scale, max_scale, and crop_size are expect to be the same. 142 | assert len({min_scale, max_scale}) == 1 143 | frames = transform.random_short_side_scale_jitter(frames, min_scale, max_scale) 144 | frames = transform.uniform_crop(frames, crop_size, spatial_idx) 145 | return frames 146 | 147 | 148 | def as_binary_vector(labels, num_classes): 149 | """ 150 | Construct binary label vector given a list of label indices. 151 | Args: 152 | labels (list): The input label list. 153 | num_classes (int): Number of classes of the label vector. 154 | Returns: 155 | labels (numpy array): the resulting binary vector. 156 | """ 157 | label_arr = np.zeros((num_classes,)) 158 | 159 | for lbl in set(labels): 160 | label_arr[lbl] = 1.0 161 | return label_arr 162 | 163 | 164 | def aggregate_labels(label_list): 165 | """ 166 | Join a list of label list. 167 | Args: 168 | labels (list): The input label list. 169 | Returns: 170 | labels (list): The joint list of all lists in input. 171 | """ 172 | all_labels = [] 173 | for labels in label_list: 174 | for l in labels: 175 | all_labels.append(l) 176 | return list(set(all_labels)) 177 | 178 | 179 | def convert_to_video_level_labels(labels): 180 | """ 181 | Aggregate annotations from all frames of a video to form video-level labels. 182 | Args: 183 | labels (list): The input label list. 184 | Returns: 185 | labels (list): Same as input, but with each label replaced by 186 | a video-level one. 187 | """ 188 | for video_id in range(len(labels)): 189 | video_level_labels = aggregate_labels(labels[video_id]) 190 | for i in range(len(labels[video_id])): 191 | labels[video_id][i] = video_level_labels 192 | return labels 193 | 194 | 195 | def load_image_lists(frame_list_file, prefix="", return_list=False): 196 | """ 197 | Load image paths and labels from a "frame list". 198 | Each line of the frame list contains: 199 | `original_vido_id video_id frame_id path labels` 200 | Args: 201 | frame_list_file (string): path to the frame list. 202 | prefix (str): the prefix for the path. 203 | return_list (bool): if True, return a list. If False, return a dict. 204 | Returns: 205 | image_paths (list or dict): list of list containing path to each frame. 206 | If return_list is False, then return in a dict form. 207 | labels (list or dict): list of list containing label of each frame. 208 | If return_list is False, then return in a dict form. 209 | """ 210 | image_paths = defaultdict(list) 211 | labels = defaultdict(list) 212 | with pathmgr.open(frame_list_file, "r") as f: 213 | assert f.readline().startswith("original_vido_id") 214 | for line in f: 215 | row = line.split() 216 | # original_vido_id video_id frame_id path labels 217 | assert len(row) == 5 218 | video_name = row[0] 219 | if prefix == "": 220 | path = row[3] 221 | else: 222 | path = os.path.join(prefix, row[3]) 223 | image_paths[video_name].append(path) 224 | frame_labels = row[-1].replace('"', "") 225 | if frame_labels != "": 226 | labels[video_name].append([int(x) for x in frame_labels.split(",")]) 227 | else: 228 | labels[video_name].append([]) 229 | 230 | if return_list: 231 | keys = image_paths.keys() 232 | image_paths = [image_paths[key] for key in keys] 233 | labels = [labels[key] for key in keys] 234 | return image_paths, labels 235 | return dict(image_paths), dict(labels) 236 | 237 | 238 | def tensor_normalize(tensor, mean, std): 239 | """ 240 | Normalize a given tensor by subtracting the mean and dividing the std. 241 | Args: 242 | tensor (tensor): tensor to normalize. 243 | mean (tensor or list): mean value to subtract. 244 | std (tensor or list): std to divide. 245 | """ 246 | if tensor.dtype == torch.uint8: 247 | tensor = tensor.float() 248 | tensor = tensor / 255.0 249 | if type(mean) == tuple: 250 | mean = torch.tensor(mean) 251 | if type(std) == tuple: 252 | std = torch.tensor(std) 253 | tensor = tensor - mean 254 | tensor = tensor / std 255 | return tensor 256 | 257 | 258 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate): 259 | """ 260 | When multigrid training uses a fewer number of frames, we randomly 261 | increase the sampling rate so that some clips cover the original span. 262 | """ 263 | if long_cycle_sampling_rate > 0: 264 | assert long_cycle_sampling_rate >= sampling_rate 265 | return random.randint(sampling_rate, long_cycle_sampling_rate) 266 | else: 267 | return sampling_rate 268 | 269 | 270 | def revert_tensor_normalize(tensor, mean, std): 271 | """ 272 | Revert normalization for a given tensor by multiplying by the std and adding the mean. 273 | Args: 274 | tensor (tensor): tensor to revert normalization. 275 | mean (tensor or list): mean value to add. 276 | std (tensor or list): std to multiply. 277 | """ 278 | if type(mean) == list: 279 | mean = torch.tensor(mean) 280 | if type(std) == list: 281 | std = torch.tensor(std) 282 | tensor = tensor * std 283 | tensor = tensor + mean 284 | return tensor 285 | 286 | 287 | def create_sampler(dataset, shuffle, cfg): 288 | """ 289 | Create sampler for the given dataset. 290 | Args: 291 | dataset (torch.utils.data.Dataset): the given dataset. 292 | shuffle (bool): set to ``True`` to have the data reshuffled 293 | at every epoch. 294 | cfg (CfgNode): configs. Details can be found in 295 | slowfast/config/defaults.py 296 | Returns: 297 | sampler (Sampler): the created sampler. 298 | """ 299 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 300 | 301 | return sampler 302 | 303 | 304 | def loader_worker_init_fn(dataset): 305 | """ 306 | Create init function passed to pytorch data loader. 307 | Args: 308 | dataset (torch.utils.data.Dataset): the given dataset. 309 | """ 310 | return None 311 | -------------------------------------------------------------------------------- /util/decoder/video_container.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import av 6 | 7 | 8 | def get_video_container(path_to_vid, multi_thread_decode=False): 9 | """ 10 | Given the path to the video, return the pyav video container. 11 | Args: 12 | path_to_vid (str): path to the video. 13 | multi_thread_decode (bool): if True, perform multi-thread decoding. 14 | backend (str): decoder backend, options include `pyav` and 15 | `torchvision`, default is `pyav`. 16 | Returns: 17 | container (container): video container. 18 | """ 19 | with open(path_to_vid, "rb") as fp: 20 | container = fp.read() 21 | return container 22 | -------------------------------------------------------------------------------- /util/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """Set up Environment.""" 6 | 7 | from iopath.common.file_io import PathManagerFactory 8 | 9 | _ENV_SETUP_DONE = False 10 | pathmgr = PathManagerFactory.get(key="mae_st") 11 | checkpoint_pathmgr = PathManagerFactory.get(key="mae_st_checkpoint") 12 | 13 | 14 | def setup_environment(): 15 | global _ENV_SETUP_DONE 16 | if _ENV_SETUP_DONE: 17 | return 18 | _ENV_SETUP_DONE = True 19 | -------------------------------------------------------------------------------- /util/kinetics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import os 6 | import random 7 | 8 | import torch 9 | import torch.utils.data 10 | 11 | from iopath.common.file_io import g_pathmgr as pathmgr 12 | from mae_st.util.decoder.decoder import get_start_end_idx, temporal_sampling 13 | from torchvision import transforms 14 | 15 | from .decoder import decoder as decoder, utils as utils, video_container as container 16 | from .decoder.random_erasing import RandomErasing 17 | from .decoder.transform import create_random_augment 18 | 19 | 20 | class Kinetics(torch.utils.data.Dataset): 21 | """ 22 | Kinetics video loader. Construct the Kinetics video loader, then sample 23 | clips from the videos. For training and validation, a single clip is 24 | randomly sampled from every video with random cropping, scaling, and 25 | flipping. For testing, multiple clips are uniformaly sampled from every 26 | video with uniform cropping. For uniform cropping, we take the left, center, 27 | and right crop if the width is larger than height, or take top, center, and 28 | bottom crop if the height is larger than the width. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | mode, 34 | path_to_data_dir, 35 | # decoding setting 36 | sampling_rate=4, 37 | num_frames=16, 38 | target_fps=30, 39 | # train aug settings 40 | train_jitter_scales=(256, 320), 41 | train_crop_size=224, 42 | train_random_horizontal_flip=True, 43 | # test setting, multi crops 44 | test_num_ensemble_views=10, 45 | test_num_spatial_crops=3, 46 | test_crop_size=256, 47 | # norm setting 48 | mean=(0.45, 0.45, 0.45), 49 | std=(0.225, 0.225, 0.225), 50 | # other parameters 51 | enable_multi_thread_decode=False, 52 | use_offset_sampling=True, 53 | inverse_uniform_sampling=False, 54 | num_retries=10, 55 | # pretrain augmentation 56 | repeat_aug=1, 57 | aa_type="rand-m7-n4-mstd0.5-inc1", 58 | pretrain_rand_flip=True, 59 | pretrain_rand_erase_prob=0.25, 60 | pretrain_rand_erase_mode="pixel", 61 | pretrain_rand_erase_count=1, 62 | pretrain_rand_erase_split=False, 63 | rand_aug=False, 64 | jitter_scales_relative=[0.5, 1.0], 65 | jitter_aspect_relative=[0.75, 1.3333], 66 | ): 67 | """ 68 | Construct the Kinetics video loader with a given csv file. The format of 69 | the csv file is: 70 | ``` 71 | path_to_video_1 label_1 72 | path_to_video_2 label_2 73 | ... 74 | path_to_video_N label_N 75 | ``` 76 | Args: 77 | mode (string): Options includes `train`, `val`, or `test` mode. 78 | For the train and val mode, the data loader will take data 79 | from the train or val set, and sample one clip per video. 80 | For the test mode, the data loader will take data from test set, 81 | and sample multiple clips per video. 82 | num_retries (int): number of retries. 83 | """ 84 | # Only support train, val, and test mode. 85 | assert mode in [ 86 | "pretrain", 87 | "finetune", 88 | "val", 89 | "test", 90 | ], "Split '{}' not supported for Kinetics".format(mode) 91 | self.mode = mode 92 | self.aa_type = aa_type 93 | self.pretrain_rand_flip = pretrain_rand_flip 94 | self.pretrain_rand_erase_prob = pretrain_rand_erase_prob 95 | self.pretrain_rand_erase_mode = pretrain_rand_erase_mode 96 | self.pretrain_rand_erase_count = pretrain_rand_erase_count 97 | self.pretrain_rand_erase_split = pretrain_rand_erase_split 98 | 99 | self.jitter_aspect_relative = jitter_aspect_relative 100 | self.jitter_scales_relative = jitter_scales_relative 101 | 102 | print( 103 | f"jitter_aspect_relative {jitter_aspect_relative} jitter_scales_relative {jitter_scales_relative}" 104 | ) 105 | 106 | self._repeat_aug = repeat_aug 107 | self._video_meta = {} 108 | self._num_retries = num_retries 109 | self._path_to_data_dir = path_to_data_dir 110 | 111 | self._train_jitter_scales = train_jitter_scales 112 | self._train_crop_size = train_crop_size 113 | self._train_random_horizontal_flip = train_random_horizontal_flip 114 | 115 | self._test_num_ensemble_views = test_num_ensemble_views 116 | self._test_num_spatial_crops = test_num_spatial_crops 117 | self._test_crop_size = test_crop_size 118 | 119 | self._sampling_rate = sampling_rate 120 | self._num_frames = num_frames 121 | self._target_fps = target_fps 122 | 123 | self._mean = mean 124 | self._std = std 125 | 126 | self._enable_multi_thread_decode = enable_multi_thread_decode 127 | self._inverse_uniform_sampling = inverse_uniform_sampling 128 | self._use_offset_sampling = use_offset_sampling 129 | 130 | print(self) 131 | print(locals()) 132 | 133 | # For training or validation mode, one single clip is sampled from every 134 | # video. For testing, NUM_ENSEMBLE_VIEWS clips are sampled from every 135 | # video. For every clip, NUM_SPATIAL_CROPS is cropped spatially from 136 | # the frames. 137 | if self.mode in ["pretrain", "finetune", "val"]: 138 | self._num_clips = 1 139 | elif self.mode in ["test"]: 140 | self._num_clips = test_num_ensemble_views * test_num_spatial_crops 141 | 142 | print("Constructing Kinetics {}...".format(mode)) 143 | self._construct_loader() 144 | if self.mode in ["pretrain", "val", "test"]: 145 | self.rand_aug = False 146 | print("Perform standard augmentation") 147 | else: 148 | self.rand_aug = rand_aug 149 | print("Perform rand augmentation") 150 | self.use_temporal_gradient = False 151 | self.temporal_gradient_rate = 0.0 152 | 153 | def _construct_loader(self): 154 | """ 155 | Construct the video loader. 156 | """ 157 | csv_file_name = { 158 | "pretrain": "train", 159 | "finetune": "train", 160 | "val": "val", 161 | "test": "test", 162 | } 163 | path_to_file = os.path.join( 164 | self._path_to_data_dir, 165 | "{}.csv".format(csv_file_name[self.mode]), 166 | ) 167 | assert pathmgr.exists(path_to_file), "{} dir not found".format(path_to_file) 168 | 169 | self._path_to_videos = [] 170 | self._labels = [] 171 | self._spatial_temporal_idx = [] 172 | with pathmgr.open(path_to_file, "r") as f: 173 | for clip_idx, path_label in enumerate(f.read().splitlines()): 174 | assert len(path_label.split()) == 2 175 | path, label = path_label.split() 176 | for idx in range(self._num_clips): 177 | self._path_to_videos.append(os.path.join(path)) 178 | self._labels.append(int(label)) 179 | self._spatial_temporal_idx.append(idx) 180 | self._video_meta[clip_idx * self._num_clips + idx] = {} 181 | assert ( 182 | len(self._path_to_videos) > 0 183 | ), "Failed to load Kinetics split {} from {}".format( 184 | self._split_idx, path_to_file 185 | ) 186 | print( 187 | "Constructing kinetics dataloader (size: {}) from {}".format( 188 | len(self._path_to_videos), path_to_file 189 | ) 190 | ) 191 | 192 | def __getitem__(self, index): 193 | """ 194 | Given the video index, return the list of frames, label, and video 195 | index if the video can be fetched and decoded successfully, otherwise 196 | repeatly find a random video that can be decoded as a replacement. 197 | Args: 198 | index (int): the video index provided by the pytorch sampler. 199 | Returns: 200 | frames (tensor): the frames of sampled from the video. The dimension 201 | is `channel` x `num frames` x `height` x `width`. 202 | label (int): the label of the current video. 203 | index (int): if the video provided by pytorch sampler can be 204 | decoded, then return the index of the video. If not, return the 205 | index of the video replacement that can be decoded. 206 | """ 207 | if self.mode in ["pretrain", "finetune", "val"]: 208 | # -1 indicates random sampling. 209 | temporal_sample_index = -1 210 | spatial_sample_index = -1 211 | min_scale, max_scale = self._train_jitter_scales 212 | crop_size = self._train_crop_size 213 | elif self.mode in ["test"]: 214 | temporal_sample_index = ( 215 | self._spatial_temporal_idx[index] // self._test_num_spatial_crops 216 | ) 217 | # spatial_sample_index is in [0, 1, 2]. Corresponding to left, 218 | # center, or right if width is larger than height, and top, middle, 219 | # or bottom if height is larger than width. 220 | spatial_sample_index = ( 221 | (self._spatial_temporal_idx[index] % self._test_num_spatial_crops) 222 | if self._test_num_spatial_crops > 1 223 | else 1 224 | ) 225 | min_scale, max_scale, crop_size = ( 226 | [self._test_crop_size] * 3 227 | if self._test_num_spatial_crops > 1 228 | else [self._train_jitter_scales[0]] * 2 + [self._test_crop_size] 229 | ) 230 | # The testing is deterministic and no jitter should be performed. 231 | # min_scale, max_scale, and crop_size are expect to be the same. 232 | assert len({min_scale, max_scale}) == 1 233 | else: 234 | raise NotImplementedError("Does not support {} mode".format(self.mode)) 235 | sampling_rate = self._sampling_rate 236 | # Try to decode and sample a clip from a video. If the video can not be 237 | # decoded, repeatly find a random video replacement that can be decoded. 238 | for i_try in range(self._num_retries): 239 | video_container = None 240 | try: 241 | video_container = container.get_video_container( 242 | self._path_to_videos[index], 243 | self._enable_multi_thread_decode, 244 | ) 245 | except Exception as e: 246 | print( 247 | "Failed to load video from {} with error {}".format( 248 | self._path_to_videos[index], e 249 | ) 250 | ) 251 | # Select a random video if the current video was not able to access. 252 | if video_container is None: 253 | print( 254 | "Failed to meta load video idx {} from {}; trial {}".format( 255 | index, self._path_to_videos[index], i_try 256 | ) 257 | ) 258 | if self.mode not in ["test"] and i_try > self._num_retries // 2: 259 | # let's try another one 260 | index = random.randint(0, len(self._path_to_videos) - 1) 261 | continue 262 | 263 | # Decode video. Meta info is used to perform selective decoding. 264 | frames, fps, decode_all_video = decoder.decode( 265 | video_container, 266 | sampling_rate, 267 | self._num_frames, 268 | temporal_sample_index, 269 | self._test_num_ensemble_views, 270 | video_meta=self._video_meta[index], 271 | target_fps=self._target_fps, 272 | max_spatial_scale=min_scale, 273 | use_offset=self._use_offset_sampling, 274 | rigid_decode_all_video=self.mode in ["pretrain"], 275 | ) 276 | 277 | # If decoding failed (wrong format, video is too short, and etc), 278 | # select another video. 279 | if frames is None: 280 | print( 281 | "Failed to decode video idx {} from {}; trial {}".format( 282 | index, self._path_to_videos[index], i_try 283 | ) 284 | ) 285 | if self.mode not in ["test"] and i_try > self._num_retries // 2: 286 | # let's try another one 287 | index = random.randint(0, len(self._path_to_videos) - 1) 288 | continue 289 | 290 | frames_list = [] 291 | label_list = [] 292 | label = self._labels[index] 293 | if self.rand_aug: 294 | for i in range(self._repeat_aug): 295 | clip_sz = sampling_rate * self._num_frames / self._target_fps * fps 296 | start_idx, end_idx = get_start_end_idx( 297 | frames.shape[0], 298 | clip_sz, 299 | temporal_sample_index if decode_all_video else 0, 300 | self._test_num_ensemble_views if decode_all_video else 1, 301 | use_offset=self._use_offset_sampling, 302 | ) 303 | # Perform temporal sampling from the decoded video. 304 | new_frames = temporal_sampling( 305 | frames, start_idx, end_idx, self._num_frames 306 | ) 307 | new_frames = self._aug_frame( 308 | new_frames, 309 | spatial_sample_index, 310 | min_scale, 311 | max_scale, 312 | crop_size, 313 | ) 314 | frames_list.append(new_frames) 315 | label_list.append(label) 316 | else: 317 | # T H W C -> C T H W. 318 | for i in range(self._repeat_aug): 319 | clip_sz = sampling_rate * self._num_frames / self._target_fps * fps 320 | start_idx, end_idx = get_start_end_idx( 321 | frames.shape[0], 322 | clip_sz, 323 | temporal_sample_index if decode_all_video else 0, 324 | self._test_num_ensemble_views if decode_all_video else 1, 325 | use_offset=self._use_offset_sampling, 326 | ) 327 | # Perform temporal sampling from the decoded video. 328 | new_frames = temporal_sampling( 329 | frames, start_idx, end_idx, self._num_frames 330 | ) 331 | 332 | new_frames = utils.tensor_normalize( 333 | new_frames, self._mean, self._std 334 | ) 335 | new_frames = new_frames.permute(3, 0, 1, 2) 336 | 337 | scl, asp = ( 338 | self.jitter_scales_relative, 339 | self.jitter_aspect_relative, 340 | ) 341 | relative_scales = ( 342 | None 343 | if (self.mode not in ["pretrain", "finetune"] or len(scl) == 0) 344 | else scl 345 | ) 346 | relative_aspect = ( 347 | None 348 | if (self.mode not in ["pretrain", "finetune"] or len(asp) == 0) 349 | else asp 350 | ) 351 | 352 | # Perform data augmentation. 353 | new_frames = utils.spatial_sampling( 354 | new_frames, 355 | spatial_idx=spatial_sample_index, 356 | min_scale=min_scale, 357 | max_scale=max_scale, 358 | crop_size=crop_size, 359 | random_horizontal_flip=self._train_random_horizontal_flip, 360 | inverse_uniform_sampling=self._inverse_uniform_sampling, 361 | aspect_ratio=relative_aspect, 362 | scale=relative_scales, 363 | ) 364 | frames_list.append(new_frames) 365 | label_list.append(label) 366 | frames = torch.stack(frames_list, dim=0) 367 | 368 | if self.mode in ["test"]: 369 | return frames, torch.tensor(label_list), index 370 | else: 371 | return frames, torch.tensor(label_list) 372 | else: 373 | raise RuntimeError( 374 | "Failed to fetch video after {} retries.".format(self._num_retries) 375 | ) 376 | 377 | def _aug_frame( 378 | self, 379 | frames, 380 | spatial_sample_index, 381 | min_scale, 382 | max_scale, 383 | crop_size, 384 | ): 385 | aug_transform = create_random_augment( 386 | input_size=(frames.size(1), frames.size(2)), 387 | auto_augment=self.aa_type, 388 | interpolation="bicubic", 389 | ) 390 | # T H W C -> T C H W. 391 | frames = frames.permute(0, 3, 1, 2) 392 | list_img = self._frame_to_list_img(frames) 393 | list_img = aug_transform(list_img) 394 | frames = self._list_img_to_frames(list_img) 395 | frames = frames.permute(0, 2, 3, 1) 396 | 397 | frames = utils.tensor_normalize( 398 | frames, 399 | (0.45, 0.45, 0.45), 400 | (0.225, 0.225, 0.225), 401 | ) 402 | # T H W C -> C T H W. 403 | frames = frames.permute(3, 0, 1, 2) 404 | # Perform data augmentation. 405 | scl, asp = ( 406 | self.jitter_scales_relative, 407 | self.jitter_aspect_relative, 408 | ) 409 | relative_scales = ( 410 | None 411 | if (self.mode not in ["pretrain", "finetune"] or len(scl) == 0) 412 | else scl 413 | ) 414 | relative_aspect = ( 415 | None 416 | if (self.mode not in ["pretrain", "finetune"] or len(asp) == 0) 417 | else asp 418 | ) 419 | frames = utils.spatial_sampling( 420 | frames, 421 | spatial_idx=spatial_sample_index, 422 | min_scale=min_scale, 423 | max_scale=max_scale, 424 | crop_size=crop_size, 425 | random_horizontal_flip=self.pretrain_rand_flip, 426 | inverse_uniform_sampling=False, 427 | aspect_ratio=relative_aspect, 428 | scale=relative_scales, 429 | motion_shift=False, 430 | ) 431 | 432 | if self.pretrain_rand_erase_prob > 0.0: 433 | erase_transform = RandomErasing( 434 | self.pretrain_rand_erase_prob, 435 | mode=self.pretrain_rand_erase_mode, 436 | max_count=self.pretrain_rand_erase_count, 437 | num_splits=self.pretrain_rand_erase_count, 438 | device="cpu", 439 | ) 440 | frames = frames.permute(1, 0, 2, 3) 441 | frames = erase_transform(frames) 442 | frames = frames.permute(1, 0, 2, 3) 443 | 444 | return frames 445 | 446 | def _frame_to_list_img(self, frames): 447 | img_list = [transforms.ToPILImage()(frames[i]) for i in range(frames.size(0))] 448 | return img_list 449 | 450 | def _list_img_to_frames(self, img_list): 451 | img_list = [transforms.ToTensor()(img) for img in img_list] 452 | return torch.stack(img_list) 453 | 454 | def __len__(self): 455 | """ 456 | Returns: 457 | (int): the number of videos in the dataset. 458 | """ 459 | return self.num_videos 460 | 461 | @property 462 | def num_videos(self): 463 | """ 464 | Returns: 465 | (int): the number of videos in the dataset. 466 | """ 467 | return len(self._path_to_videos) 468 | -------------------------------------------------------------------------------- /util/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """Logging.""" 6 | 7 | import atexit 8 | import builtins 9 | import decimal 10 | import functools 11 | import logging 12 | import os 13 | import sys 14 | 15 | import simplejson 16 | import torch 17 | import torch.distributed as dist 18 | from iopath.common.file_io import g_pathmgr as pathmgr 19 | 20 | 21 | def is_master_proc(multinode=False): 22 | """ 23 | Determines if the current process is the master process. 24 | """ 25 | if dist.is_initialized(): 26 | if multinode: 27 | return dist.get_rank() % dist.get_world_size() == 0 28 | else: 29 | return dist.get_rank() % torch.cuda.device_count() == 0 30 | else: 31 | return True 32 | 33 | 34 | def _suppress_print(): 35 | """ 36 | Suppresses printing from the current process. 37 | """ 38 | 39 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 40 | pass 41 | 42 | builtins.print = print_pass 43 | 44 | 45 | @functools.lru_cache(maxsize=None) 46 | def _cached_log_stream(filename): 47 | # Use 1K buffer if writing to cloud storage. 48 | io = pathmgr.open(filename, "a", buffering=1024 if "://" in filename else -1) 49 | atexit.register(io.close) 50 | return io 51 | 52 | 53 | def setup_logging(output_dir=None): 54 | """ 55 | Sets up the logging for multiple processes. Only enable the logging for the 56 | master process, and suppress logging for the non-master processes. 57 | """ 58 | # Set up logging format. 59 | if is_master_proc(): 60 | # Enable logging for the master process. 61 | logging.root.handlers = [] 62 | else: 63 | # Suppress logging for non-master processes. 64 | _suppress_print() 65 | 66 | logger = logging.getLogger() 67 | logger.setLevel(logging.DEBUG) 68 | logger.propagate = False 69 | plain_formatter = logging.Formatter( 70 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s", 71 | datefmt="%m/%d %H:%M:%S", 72 | ) 73 | 74 | if is_master_proc(): 75 | ch = logging.StreamHandler(stream=sys.stdout) 76 | ch.setLevel(logging.DEBUG) 77 | ch.setFormatter(plain_formatter) 78 | logger.addHandler(ch) 79 | 80 | if output_dir is not None and is_master_proc(multinode=True): 81 | filename = os.path.join(output_dir, "stdout.log") 82 | fh = logging.StreamHandler(_cached_log_stream(filename)) 83 | fh.setLevel(logging.DEBUG) 84 | fh.setFormatter(plain_formatter) 85 | logger.addHandler(fh) 86 | 87 | 88 | def get_logger(name): 89 | """ 90 | Retrieve the logger with the specified name or, if name is None, return a 91 | logger which is the root logger of the hierarchy. 92 | Args: 93 | name (string): name of the logger. 94 | """ 95 | return logging.getLogger(name) 96 | 97 | 98 | def log_json_stats(stats): 99 | """ 100 | Logs json stats. 101 | Args: 102 | stats (dict): a dictionary of statistical information to log. 103 | """ 104 | stats = { 105 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v 106 | for k, v in stats.items() 107 | } 108 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 109 | logger = get_logger(__name__) 110 | print("json_stats: {:s}".format(json_stats)) 111 | 112 | 113 | def master_print(*args, **kwargs): 114 | if is_master_proc(): 115 | print(*args, **kwargs) 116 | else: 117 | pass 118 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd( 16 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 17 | ): 18 | """ 19 | Parameter groups for layer-wise lr decay 20 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 21 | """ 22 | param_group_names = {} 23 | param_groups = {} 24 | 25 | num_layers = len(model.blocks) + 1 26 | 27 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 28 | 29 | for n, p in model.named_parameters(): 30 | if not p.requires_grad: 31 | continue 32 | 33 | # no decay: all 1D parameters and model specific ones 34 | if p.ndim == 1 or n in no_weight_decay_list: 35 | g_decay = "no_decay" 36 | this_decay = 0.0 37 | else: 38 | g_decay = "decay" 39 | this_decay = weight_decay 40 | 41 | layer_id = get_layer_id_for_vit(n, num_layers) 42 | group_name = "layer_%d_%s" % (layer_id, g_decay) 43 | 44 | if group_name not in param_group_names: 45 | this_scale = layer_scales[layer_id] 46 | 47 | param_group_names[group_name] = { 48 | "lr_scale": this_scale, 49 | "weight_decay": this_decay, 50 | "params": [], 51 | } 52 | param_groups[group_name] = { 53 | "lr_scale": this_scale, 54 | "weight_decay": this_decay, 55 | "params": [], 56 | } 57 | 58 | param_group_names[group_name]["params"].append(n) 59 | param_groups[group_name]["params"].append(p) 60 | 61 | print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 62 | 63 | return list(param_groups.values()) 64 | 65 | 66 | def get_layer_id_for_vit(name, num_layers): 67 | """ 68 | Assign a parameter with its layer id 69 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 70 | """ 71 | if name in [ 72 | "cls_token", 73 | "mask_token", 74 | ]: 75 | return 0 76 | elif name.startswith("patch_embed"): 77 | return 0 78 | elif name.startswith("pos_embed"): 79 | return 0 80 | elif name.startswith("blocks"): 81 | return int(name.split(".")[1]) + 1 82 | else: 83 | return num_layers 84 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | 10 | def adjust_learning_rate(optimizer, epoch, args): 11 | """Decay the learning rate with half-cycle cosine after warmup""" 12 | if epoch < args.warmup_epochs: 13 | lr = args.lr * epoch / args.warmup_epochs 14 | else: 15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( 16 | 1.0 17 | + math.cos( 18 | math.pi 19 | * (epoch - args.warmup_epochs) 20 | / (args.epochs - args.warmup_epochs) 21 | ) 22 | ) 23 | for param_group in optimizer.param_groups: 24 | if "lr_scale" in param_group: 25 | param_group["lr"] = lr * param_group["lr_scale"] 26 | else: 27 | param_group["lr"] = lr 28 | return lr 29 | -------------------------------------------------------------------------------- /util/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import average_precision_score 8 | 9 | 10 | def topks_correct(preds, labels, ks): 11 | """ 12 | Given the predictions, labels, and a list of top-k values, compute the 13 | number of correct predictions for each top-k value. 14 | 15 | Args: 16 | preds (array): array of predictions. Dimension is batchsize 17 | N x ClassNum. 18 | labels (array): array of labels. Dimension is batchsize N. 19 | ks (list): list of top-k values. For example, ks = [1, 5] correspods 20 | to top-1 and top-5. 21 | 22 | Returns: 23 | topks_correct (list): list of numbers, where the `i`-th entry 24 | corresponds to the number of top-`ks[i]` correct predictions. 25 | """ 26 | assert preds.size(0) == labels.size( 27 | 0 28 | ), "Batch dim of predictions and labels must match" 29 | # Find the top max_k predictions for each sample 30 | _top_max_k_vals, top_max_k_inds = torch.topk( 31 | preds, max(ks), dim=1, largest=True, sorted=True 32 | ) 33 | # (batch_size, max_k) -> (max_k, batch_size). 34 | top_max_k_inds = top_max_k_inds.t() 35 | # (batch_size, ) -> (max_k, batch_size). 36 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 37 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct. 38 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 39 | # Compute the number of topk correct predictions for each k. 40 | topks_correct = [top_max_k_correct[:k, :].float().sum() for k in ks] 41 | return topks_correct 42 | 43 | 44 | def topk_errors(preds, labels, ks): 45 | """ 46 | Computes the top-k error for each k. 47 | Args: 48 | preds (array): array of predictions. Dimension is N. 49 | labels (array): array of labels. Dimension is N. 50 | ks (list): list of ks to calculate the top accuracies. 51 | """ 52 | num_topks_correct = topks_correct(preds, labels, ks) 53 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 54 | 55 | 56 | def topk_accuracies(preds, labels, ks): 57 | """ 58 | Computes the top-k accuracy for each k. 59 | Args: 60 | preds (array): array of predictions. Dimension is N. 61 | labels (array): array of labels. Dimension is N. 62 | ks (list): list of ks to calculate the top accuracies. 63 | """ 64 | num_topks_correct = topks_correct(preds, labels, ks) 65 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 66 | 67 | 68 | def get_map(preds, labels): 69 | """ 70 | Compute mAP for multi-label case. 71 | Args: 72 | preds (numpy tensor): num_examples x num_classes. 73 | labels (numpy tensor): num_examples x num_classes. 74 | Returns: 75 | mean_ap (int): final mAP score. 76 | """ 77 | 78 | print("Getting mAP for {} examples".format(preds.shape[0])) 79 | 80 | preds = preds[:, ~(np.all(labels == 0, axis=0))] 81 | labels = labels[:, ~(np.all(labels == 0, axis=0))] 82 | aps = [0] 83 | try: 84 | aps = average_precision_score(labels, preds, average=None) 85 | except ValueError: 86 | print( 87 | "Average precision requires a sufficient number of samples \ 88 | in a batch which are missing in this sample." 89 | ) 90 | 91 | mean_ap = np.mean(aps) 92 | return mean_ap 93 | 94 | 95 | class TestMeter: 96 | """ 97 | Perform the multi-view ensemble for testing: each video with an unique index 98 | will be sampled with multiple clips, and the predictions of the clips will 99 | be aggregated to produce the final prediction for the video. 100 | The accuracy is calculated with the given ground truth labels. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | num_videos, 106 | num_clips, 107 | num_cls, 108 | overall_iters, 109 | multi_label=False, 110 | ensemble_method="sum", 111 | ): 112 | """ 113 | Construct tensors to store the predictions and labels. Expect to get 114 | num_clips predictions from each video, and calculate the metrics on 115 | num_videos videos. 116 | Args: 117 | num_videos (int): number of videos to test. 118 | num_clips (int): number of clips sampled from each video for 119 | aggregating the final prediction for the video. 120 | num_cls (int): number of classes for each prediction. 121 | overall_iters (int): overall iterations for testing. 122 | multi_label (bool): if True, use map as the metric. 123 | ensemble_method (str): method to perform the ensemble, options 124 | include "sum", and "max". 125 | """ 126 | 127 | self.num_clips = num_clips 128 | self.overall_iters = overall_iters 129 | self.multi_label = multi_label 130 | self.ensemble_method = ensemble_method 131 | # Initialize tensors. 132 | self.video_preds = torch.zeros((num_videos, num_cls)) 133 | if multi_label: 134 | self.video_preds -= 1e10 135 | 136 | self.video_labels = ( 137 | torch.zeros((num_videos, num_cls)) 138 | if multi_label 139 | else torch.zeros((num_videos)).long() 140 | ) 141 | self.clip_count = torch.zeros((num_videos)).long() 142 | self.topk_accs = [] 143 | self.stats = {} 144 | 145 | # Reset metric. 146 | self.reset() 147 | 148 | def reset(self): 149 | """ 150 | Reset the metric. 151 | """ 152 | self.clip_count.zero_() 153 | self.video_preds.zero_() 154 | if self.multi_label: 155 | self.video_preds -= 1e10 156 | self.video_labels.zero_() 157 | 158 | def update_stats(self, preds, labels, clip_ids): 159 | """ 160 | Collect the predictions from the current batch and perform on-the-flight 161 | summation as ensemble. 162 | Args: 163 | preds (tensor): predictions from the current batch. Dimension is 164 | N x C where N is the batch size and C is the channel size 165 | (num_cls). 166 | labels (tensor): the corresponding labels of the current batch. 167 | Dimension is N. 168 | clip_ids (tensor): clip indexes of the current batch, dimension is 169 | N. 170 | """ 171 | for ind in range(preds.shape[0]): 172 | vid_id = int(clip_ids[ind]) // self.num_clips 173 | if self.video_labels[vid_id].sum() > 0: 174 | assert torch.equal( 175 | self.video_labels[vid_id].type(torch.FloatTensor), 176 | labels[ind].type(torch.FloatTensor), 177 | ) 178 | self.video_labels[vid_id] = labels[ind] 179 | if self.ensemble_method == "sum": 180 | self.video_preds[vid_id] += preds[ind] 181 | elif self.ensemble_method == "max": 182 | self.video_preds[vid_id] = torch.max( 183 | self.video_preds[vid_id], preds[ind] 184 | ) 185 | else: 186 | raise NotImplementedError( 187 | "Ensemble Method {} is not supported".format(self.ensemble_method) 188 | ) 189 | self.clip_count[vid_id] += 1 190 | 191 | def log_iter_stats(self, cur_iter): 192 | """ 193 | Log the stats. 194 | Args: 195 | cur_iter (int): the current iteration of testing. 196 | """ 197 | stats = { 198 | "split": "test_iter", 199 | "cur_iter": "{}".format(cur_iter + 1), 200 | } 201 | print(stats) 202 | 203 | def finalize_metrics(self, ks=(1, 5)): 204 | """ 205 | Calculate and log the final ensembled metrics. 206 | ks (tuple): list of top-k values for topk_accuracies. For example, 207 | ks = (1, 5) correspods to top-1 and top-5 accuracy. 208 | """ 209 | if not all(self.clip_count == self.num_clips): 210 | print( 211 | "clip count {} ~= num clips {}".format( 212 | ", ".join( 213 | [ 214 | "{}: {}".format(i, k) 215 | for i, k in enumerate(self.clip_count.tolist()) 216 | ] 217 | ), 218 | self.num_clips, 219 | ) 220 | ) 221 | 222 | self.stats = {"split": "test_final"} 223 | if self.multi_label: 224 | map = get_map( 225 | self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy() 226 | ) 227 | self.stats["map"] = map 228 | else: 229 | num_topks_correct = topks_correct(self.video_preds, self.video_labels, ks) 230 | topks = [(x / self.video_preds.size(0)) * 100.0 for x in num_topks_correct] 231 | assert len({len(ks), len(topks)}) == 1 232 | for k, topk in zip(ks, topks): 233 | self.stats["top{}_acc".format(k)] = "{:.{prec}f}".format(topk, prec=2) 234 | print(self.stats) 235 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import math 15 | import os 16 | import time 17 | from collections import defaultdict, deque, OrderedDict 18 | 19 | import mae_st.util.logging as logging 20 | import psutil 21 | import torch 22 | import torch.distributed as dist 23 | import torch.fb.rendezvous.zeus 24 | from iopath.common.file_io import g_pathmgr as pathmgr 25 | from mae_st.util.logging import master_print as print 26 | from torch import inf 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | class SmoothedValue: 33 | """Track a series of values and provide access to smoothed values over a 34 | window or the global series average. 35 | """ 36 | 37 | def __init__(self, window_size=20, fmt=None): 38 | if fmt is None: 39 | fmt = "{median:.4f} ({global_avg:.4f})" 40 | self.deque = deque(maxlen=window_size) 41 | self.total = 0.0 42 | self.count = 0 43 | self.fmt = fmt 44 | 45 | def update(self, value, n=1): 46 | self.deque.append(value) 47 | self.count += n 48 | self.total += value * n 49 | 50 | def synchronize_between_processes(self): 51 | """ 52 | Warning: does not synchronize the deque! 53 | """ 54 | if not is_dist_avail_and_initialized(): 55 | return 56 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 57 | dist.barrier() 58 | dist.all_reduce(t) 59 | t = t.tolist() 60 | self.count = int(t[0]) 61 | self.total = t[1] 62 | 63 | @property 64 | def median(self): 65 | d = torch.tensor(list(self.deque)) 66 | return d.median().item() 67 | 68 | @property 69 | def avg(self): 70 | d = torch.tensor(list(self.deque), dtype=torch.float32) 71 | return d.mean().item() 72 | 73 | @property 74 | def global_avg(self): 75 | return self.total / self.count 76 | 77 | @property 78 | def max(self): 79 | return max(self.deque) 80 | 81 | @property 82 | def value(self): 83 | return self.deque[-1] 84 | 85 | def __str__(self): 86 | return self.fmt.format( 87 | median=self.median, 88 | avg=self.avg, 89 | global_avg=self.global_avg, 90 | max=self.max, 91 | value=self.value, 92 | ) 93 | 94 | 95 | class MetricLogger: 96 | def __init__(self, delimiter="\t"): 97 | self.meters = defaultdict(SmoothedValue) 98 | self.delimiter = delimiter 99 | 100 | def update(self, **kwargs): 101 | for k, v in kwargs.items(): 102 | if v is None: 103 | continue 104 | if isinstance(v, torch.Tensor): 105 | v = v.item() 106 | assert isinstance(v, (float, int)) 107 | self.meters[k].update(v) 108 | 109 | def __getattr__(self, attr): 110 | if attr in self.meters: 111 | return self.meters[attr] 112 | if attr in self.__dict__: 113 | return self.__dict__[attr] 114 | raise AttributeError( 115 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 116 | ) 117 | 118 | def __str__(self): 119 | loss_str = [] 120 | for name, meter in self.meters.items(): 121 | loss_str.append("{}: {}".format(name, str(meter))) 122 | return self.delimiter.join(loss_str) 123 | 124 | def synchronize_between_processes(self): 125 | for meter in self.meters.values(): 126 | meter.synchronize_between_processes() 127 | 128 | def add_meter(self, name, meter): 129 | self.meters[name] = meter 130 | 131 | def log_every(self, iterable, print_freq, header=None): 132 | i = 0 133 | if not header: 134 | header = "" 135 | start_time = time.time() 136 | end = time.time() 137 | iter_time = SmoothedValue(fmt="{avg:.4f}") 138 | data_time = SmoothedValue(fmt="{avg:.4f}") 139 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 140 | log_msg = [ 141 | header, 142 | "[{0" + space_fmt + "}/{1}]", 143 | "eta: {eta}", 144 | "{meters}", 145 | "time: {time}", 146 | "data: {data}", 147 | ] 148 | if torch.cuda.is_available(): 149 | log_msg.append("max mem: {memory:.0f}") 150 | log_msg = self.delimiter.join(log_msg) 151 | MB = 1024.0 * 1024.0 152 | for obj in iterable: 153 | data_time.update(time.time() - end) 154 | yield obj 155 | iter_time.update(time.time() - end) 156 | if i % print_freq == 0 or i == len(iterable) - 1: 157 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 158 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 159 | if torch.cuda.is_available(): 160 | print( 161 | log_msg.format( 162 | i, 163 | len(iterable), 164 | eta=eta_string, 165 | meters=str(self), 166 | time=str(iter_time), 167 | data=str(data_time), 168 | memory=torch.cuda.max_memory_allocated() / MB, 169 | ) 170 | ) 171 | 172 | else: 173 | print( 174 | log_msg.format( 175 | i, 176 | len(iterable), 177 | eta=eta_string, 178 | meters=str(self), 179 | time=str(iter_time), 180 | data=str(data_time), 181 | ) 182 | ) 183 | i += 1 184 | end = time.time() 185 | total_time = time.time() - start_time 186 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 187 | print( 188 | "{} Total time: {} ({:.4f} s / it)".format( 189 | header, total_time_str, total_time / len(iterable) 190 | ) 191 | ) 192 | 193 | 194 | def setup_for_distributed(is_master): 195 | """ 196 | This function disables printing when not in master process 197 | """ 198 | builtin_print = builtins.print 199 | 200 | def print(*args, **kwargs): 201 | force = kwargs.pop("force", False) 202 | force = force or (get_world_size() > 8) 203 | if is_master or force: 204 | now = datetime.datetime.now().time() 205 | builtin_print("[{}] ".format(now), end="") # print with time stamp 206 | builtin_print(*args, **kwargs) 207 | 208 | builtins.print = print 209 | 210 | 211 | def is_dist_avail_and_initialized(): 212 | if not dist.is_available(): 213 | return False 214 | if not dist.is_initialized(): 215 | return False 216 | return True 217 | 218 | 219 | def get_world_size(): 220 | if not is_dist_avail_and_initialized(): 221 | return 1 222 | return dist.get_world_size() 223 | 224 | 225 | def get_rank(): 226 | if not is_dist_avail_and_initialized(): 227 | return 0 228 | return dist.get_rank() 229 | 230 | 231 | def is_main_process(): 232 | return get_rank() == 0 233 | 234 | 235 | def save_on_master(state, path): 236 | if is_main_process(): 237 | print(f"save path {path}") 238 | with pathmgr.open(path, "wb") as f: 239 | torch.save(state, f) 240 | 241 | 242 | def init_distributed_mode(args): 243 | if args.no_env: 244 | pass 245 | elif args.dist_on_itp: 246 | args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 247 | args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 248 | args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 249 | args.dist_url = "tcp://%s:%s" % ( 250 | os.environ["MASTER_ADDR"], 251 | os.environ["MASTER_PORT"], 252 | ) 253 | os.environ["LOCAL_RANK"] = str(args.gpu) 254 | os.environ["RANK"] = str(args.rank) 255 | os.environ["WORLD_SIZE"] = str(args.world_size) 256 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 257 | elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: 258 | args.rank = int(os.environ["RANK"]) 259 | args.world_size = int(os.environ["WORLD_SIZE"]) 260 | args.gpu = int(os.environ["LOCAL_RANK"]) 261 | elif "SLURM_PROCID" in os.environ: 262 | args.rank = int(os.environ["SLURM_PROCID"]) 263 | args.gpu = args.rank % torch.cuda.device_count() 264 | else: 265 | print("Not using distributed mode") 266 | setup_for_distributed(is_master=True) # hack 267 | args.distributed = False 268 | return 269 | 270 | args.distributed = True 271 | 272 | torch.cuda.set_device(args.gpu) 273 | args.dist_backend = "nccl" 274 | print( 275 | "| distributed init (rank {}): {}, gpu {}".format( 276 | args.rank, args.dist_url, args.gpu 277 | ), 278 | # flush=True, 279 | ) 280 | torch.distributed.init_process_group( 281 | backend=args.dist_backend, 282 | init_method=args.dist_url, 283 | world_size=args.world_size, 284 | rank=args.rank, 285 | ) 286 | torch.distributed.barrier() 287 | setup_for_distributed(args.rank == 0) 288 | 289 | 290 | class NativeScalerWithGradNormCount: 291 | state_dict_key = "amp_scaler" 292 | 293 | def __init__(self, fp32=False): 294 | self._scaler = torch.cuda.amp.GradScaler(enabled=not fp32) 295 | 296 | def __call__( 297 | self, 298 | loss, 299 | optimizer, 300 | clip_grad=None, 301 | parameters=None, 302 | create_graph=False, 303 | update_grad=True, 304 | ): 305 | self._scaler.scale(loss).backward(create_graph=create_graph) 306 | if update_grad: 307 | if clip_grad is not None: 308 | assert parameters is not None 309 | self._scaler.unscale_( 310 | optimizer 311 | ) # unscale the gradients of optimizer's assigned params in-place 312 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 313 | else: 314 | self._scaler.unscale_(optimizer) 315 | norm = get_grad_norm_(parameters) 316 | self._scaler.step(optimizer) 317 | self._scaler.update() 318 | else: 319 | norm = None 320 | return norm 321 | 322 | def state_dict(self): 323 | return self._scaler.state_dict() 324 | 325 | def load_state_dict(self, state_dict): 326 | self._scaler.load_state_dict(state_dict) 327 | 328 | 329 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 330 | if isinstance(parameters, torch.Tensor): 331 | parameters = [parameters] 332 | parameters = [p for p in parameters if p.grad is not None] 333 | norm_type = float(norm_type) 334 | if len(parameters) == 0: 335 | return torch.tensor(0.0) 336 | device = parameters[0].grad.device 337 | if norm_type == inf: 338 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 339 | else: 340 | total_norm = torch.norm( 341 | torch.stack( 342 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] 343 | ), 344 | norm_type, 345 | ) 346 | return total_norm 347 | 348 | 349 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 350 | checkpoint_path = "{}/checkpoint-{:05d}.pth".format(args.output_dir, epoch) 351 | to_save = { 352 | "model": model_without_ddp.state_dict(), 353 | "optimizer": optimizer.state_dict(), 354 | "epoch": epoch, 355 | "scaler": loss_scaler.state_dict(), 356 | "args": args, 357 | } 358 | 359 | save_on_master(to_save, checkpoint_path) 360 | return checkpoint_path 361 | 362 | 363 | def get_last_checkpoint(args): 364 | """ 365 | Get the last checkpoint from the checkpointing folder. 366 | Args: 367 | path_to_job (string): the path to the folder of the current job. 368 | """ 369 | d = args.output_dir 370 | names = pathmgr.ls(d) if pathmgr.exists(d) else [] 371 | names = [f for f in names if "checkpoint" in f] 372 | if len(names) == 0: 373 | print("No checkpoints found in '{}'.".format(d)) 374 | return None 375 | else: 376 | # Sort the checkpoints by epoch. 377 | name = sorted(names)[-1] 378 | return os.path.join(d, name) 379 | 380 | 381 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 382 | if not args.resume: 383 | args.resume = get_last_checkpoint(args) 384 | if args.resume: 385 | if args.resume.startswith("https"): 386 | checkpoint = torch.hub.load_state_dict_from_url( 387 | args.resume, map_location="cpu", check_hash=True 388 | ) 389 | else: 390 | with pathmgr.open(args.resume, "rb") as f: 391 | checkpoint = torch.load(f, map_location="cpu") 392 | model_without_ddp.load_state_dict(checkpoint["model"]) 393 | print("Resume checkpoint %s" % args.resume) 394 | if ( 395 | "optimizer" in checkpoint 396 | and "epoch" in checkpoint 397 | and not (hasattr(args, "eval") and args.eval) 398 | ): 399 | optimizer.load_state_dict(checkpoint["optimizer"]) 400 | args.start_epoch = checkpoint["epoch"] + 1 401 | if "scaler" in checkpoint: 402 | loss_scaler.load_state_dict(checkpoint["scaler"]) 403 | print("With optim & sched!") 404 | 405 | 406 | def all_reduce_mean(x): 407 | world_size = get_world_size() 408 | if world_size > 1: 409 | x_reduce = torch.tensor(x).cuda() 410 | dist.all_reduce(x_reduce) 411 | x_reduce /= world_size 412 | return x_reduce.item() 413 | else: 414 | return x 415 | 416 | 417 | def gpu_mem_usage(): 418 | """ 419 | Compute the GPU memory usage for the current device (GB). 420 | """ 421 | if torch.cuda.is_available(): 422 | mem_usage_bytes = torch.cuda.max_memory_allocated() 423 | else: 424 | mem_usage_bytes = 0 425 | return mem_usage_bytes / 1024**3 426 | 427 | 428 | def cpu_mem_usage(): 429 | """ 430 | Compute the system memory (RAM) usage for the current device (GB). 431 | Returns: 432 | usage (float): used memory (GB). 433 | total (float): total memory (GB). 434 | """ 435 | vram = psutil.virtual_memory() 436 | usage = (vram.total - vram.available) / 1024**3 437 | total = vram.total / 1024**3 438 | 439 | return usage, total 440 | 441 | 442 | def all_gather(tensors): 443 | """ 444 | All gathers the provided tensors from all processes across machines. 445 | Args: 446 | tensors (list): tensors to perform all gather across all processes in 447 | all machines. 448 | """ 449 | 450 | gather_list = [] 451 | output_tensor = [] 452 | world_size = dist.get_world_size() 453 | for tensor in tensors: 454 | tensor_placeholder = [torch.ones_like(tensor) for _ in range(world_size)] 455 | dist.all_gather(tensor_placeholder, tensor, async_op=False) 456 | gather_list.append(tensor_placeholder) 457 | for gathered_tensor in gather_list: 458 | output_tensor.append(torch.cat(gathered_tensor, dim=0)) 459 | return output_tensor 460 | 461 | 462 | def add_weight_decay(model, weight_decay=1e-5, skip_list=(), bias_wd=False): 463 | decay = [] 464 | no_decay = [] 465 | for name, param in model.named_parameters(): 466 | if not param.requires_grad: 467 | continue # frozen weights 468 | if ( 469 | (not bias_wd) 470 | and len(param.shape) == 1 471 | or name.endswith(".bias") 472 | or name in skip_list 473 | ): 474 | no_decay.append(param) 475 | else: 476 | decay.append(param) 477 | return [ 478 | {"params": no_decay, "weight_decay": 0.0}, 479 | {"params": decay, "weight_decay": weight_decay}, 480 | ] 481 | 482 | 483 | def inflate(model_2d, model_3d): 484 | state_dict_inflated = OrderedDict() 485 | for k, v2d in model_2d.items(): 486 | if "patch_embed.proj.weight" in k: 487 | v3d = model_3d[k] 488 | v3d = v2d.unsqueeze(2).repeat(1, 1, v3d.shape[2], 1, 1) / v3d.shape[2] 489 | state_dict_inflated[k] = v3d.clone() 490 | elif "pos_embed" in k: 491 | pos_embed_cls, pos_embed_spatial = torch.split(v2d, [1, 196], dim=1) 492 | state_dict_inflated["pos_embed_cls"] = pos_embed_cls.clone() 493 | state_dict_inflated["pos_embed"] = pos_embed_spatial.clone() 494 | else: 495 | state_dict_inflated[k] = v2d.clone() 496 | return state_dict_inflated 497 | 498 | 499 | def convert_checkpoint(model_2d): 500 | state_dict_inflated = OrderedDict() 501 | for k, v2d in model_2d.items(): 502 | if "head.projection.weight" in k: 503 | state_dict_inflated["head.weight"] = v2d.clone() 504 | elif "head.projection.bias" in k: 505 | state_dict_inflated["head.bias"] = v2d.clone() 506 | else: 507 | state_dict_inflated[k] = v2d.clone() 508 | return state_dict_inflated 509 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import mae_st.util.logging as logging 11 | import numpy as np 12 | import torch 13 | 14 | 15 | logger = logging.get_logger(__name__) 16 | 17 | 18 | # -------------------------------------------------------- 19 | # Interpolate position embeddings for high-resolution 20 | # References: 21 | # DeiT: https://github.com/facebookresearch/deit 22 | # -------------------------------------------------------- 23 | def interpolate_pos_embed(model, checkpoint_model): 24 | if "pos_embed" in checkpoint_model: 25 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 26 | embedding_size = pos_embed_checkpoint.shape[-1] 27 | num_patches = model.patch_embed.num_patches 28 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 29 | # height (== width) for the checkpoint position embedding 30 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 31 | # height (== width) for the new position embedding 32 | new_size = int(num_patches**0.5) 33 | # class_token and dist_token are kept unchanged 34 | if orig_size != new_size: 35 | print( 36 | "Position interpolate from %dx%d to %dx%d" 37 | % (orig_size, orig_size, new_size, new_size) 38 | ) 39 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 40 | # only the position tokens are interpolated 41 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 42 | pos_tokens = pos_tokens.reshape( 43 | -1, orig_size, orig_size, embedding_size 44 | ).permute(0, 3, 1, 2) 45 | pos_tokens = torch.nn.functional.interpolate( 46 | pos_tokens, 47 | size=(new_size, new_size), 48 | mode="bicubic", 49 | align_corners=False, 50 | ) 51 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 52 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 53 | checkpoint_model["pos_embed"] = new_pos_embed 54 | -------------------------------------------------------------------------------- /util/video_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import mae_st.util.logging as logging 6 | import torch 7 | import torch.nn as nn 8 | from timm.models.layers import to_2tuple 9 | from timm.models.vision_transformer import DropPath, Mlp 10 | 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """Image to Patch Embedding""" 17 | 18 | def __init__( 19 | self, 20 | img_size=224, 21 | patch_size=16, 22 | in_chans=3, 23 | embed_dim=768, 24 | # temporal related: 25 | frames=32, 26 | t_patch_size=4, 27 | ): 28 | super().__init__() 29 | img_size = to_2tuple(img_size) 30 | patch_size = to_2tuple(patch_size) 31 | assert img_size[1] % patch_size[1] == 0 32 | assert img_size[0] % patch_size[0] == 0 33 | assert frames % t_patch_size == 0 34 | num_patches = ( 35 | (img_size[1] // patch_size[1]) 36 | * (img_size[0] // patch_size[0]) 37 | * (frames // t_patch_size) 38 | ) 39 | self.input_size = ( 40 | frames // t_patch_size, 41 | img_size[0] // patch_size[0], 42 | img_size[1] // patch_size[1], 43 | ) 44 | print( 45 | f"img_size {img_size} patch_size {patch_size} frames {frames} t_patch_size {t_patch_size}" 46 | ) 47 | self.img_size = img_size 48 | self.patch_size = patch_size 49 | 50 | self.frames = frames 51 | self.t_patch_size = t_patch_size 52 | 53 | self.num_patches = num_patches 54 | 55 | self.grid_size = img_size[0] // patch_size[0] 56 | self.t_grid_size = frames // t_patch_size 57 | 58 | kernel_size = [t_patch_size] + list(patch_size) 59 | self.proj = nn.Conv3d( 60 | in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size 61 | ) 62 | 63 | def forward(self, x): 64 | B, C, T, H, W = x.shape 65 | assert ( 66 | H == self.img_size[0] and W == self.img_size[1] 67 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 68 | assert T == self.frames 69 | x = self.proj(x).flatten(3) 70 | x = torch.einsum("ncts->ntsc", x) # [N, T, H*W, C] 71 | return x 72 | 73 | 74 | class Attention(nn.Module): 75 | def __init__( 76 | self, 77 | dim, 78 | num_heads=8, 79 | qkv_bias=False, 80 | qk_scale=None, 81 | attn_drop=0.0, 82 | proj_drop=0.0, 83 | input_size=(4, 14, 14), 84 | ): 85 | super().__init__() 86 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 87 | self.num_heads = num_heads 88 | head_dim = dim // num_heads 89 | self.scale = qk_scale or head_dim**-0.5 90 | 91 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 92 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 93 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 94 | assert attn_drop == 0.0 # do not use 95 | self.proj = nn.Linear(dim, dim) 96 | self.proj_drop = nn.Dropout(proj_drop) 97 | self.input_size = input_size 98 | assert input_size[1] == input_size[2] 99 | 100 | def forward(self, x): 101 | B, N, C = x.shape 102 | q = ( 103 | self.q(x) 104 | .reshape(B, N, self.num_heads, C // self.num_heads) 105 | .permute(0, 2, 1, 3) 106 | ) 107 | k = ( 108 | self.k(x) 109 | .reshape(B, N, self.num_heads, C // self.num_heads) 110 | .permute(0, 2, 1, 3) 111 | ) 112 | v = ( 113 | self.v(x) 114 | .reshape(B, N, self.num_heads, C // self.num_heads) 115 | .permute(0, 2, 1, 3) 116 | ) 117 | 118 | attn = (q @ k.transpose(-2, -1)) * self.scale 119 | 120 | attn = attn.softmax(dim=-1) 121 | 122 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 123 | x = self.proj(x) 124 | x = self.proj_drop(x) 125 | x = x.view(B, -1, C) 126 | return x 127 | 128 | 129 | class Block(nn.Module): 130 | """ 131 | Transformer Block with specified Attention function 132 | """ 133 | 134 | def __init__( 135 | self, 136 | dim, 137 | num_heads, 138 | mlp_ratio=4.0, 139 | qkv_bias=False, 140 | qk_scale=None, 141 | drop=0.0, 142 | attn_drop=0.0, 143 | drop_path=0.0, 144 | act_layer=nn.GELU, 145 | norm_layer=nn.LayerNorm, 146 | attn_func=Attention, 147 | ): 148 | super().__init__() 149 | self.norm1 = norm_layer(dim) 150 | self.attn = attn_func( 151 | dim, 152 | num_heads=num_heads, 153 | qkv_bias=qkv_bias, 154 | qk_scale=qk_scale, 155 | attn_drop=attn_drop, 156 | proj_drop=drop, 157 | ) 158 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 159 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 160 | self.norm2 = norm_layer(dim) 161 | mlp_hidden_dim = int(dim * mlp_ratio) 162 | self.mlp = Mlp( 163 | in_features=dim, 164 | hidden_features=mlp_hidden_dim, 165 | act_layer=act_layer, 166 | drop=drop, 167 | ) 168 | 169 | def forward(self, x): 170 | x = x + self.drop_path(self.attn(self.norm1(x))) 171 | x = x + self.drop_path(self.mlp(self.norm2(x))) 172 | return x 173 | --------------------------------------------------------------------------------