├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE.models ├── README.md ├── examples ├── benchmark.ipynb ├── img │ ├── dog.jpg │ ├── hiera_arch.png │ └── inference_speed.png ├── inference.ipynb └── vid │ ├── dog.mp4 │ └── goat.mp4 ├── hiera ├── __init__.py ├── benchmarking.py ├── hfhub.py ├── hiera.py ├── hiera_mae.py └── hiera_utils.py ├── hubconf.py └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # Note: this is a github action and not part of the hiera codebase. 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ 3 | core 4 | *.pyd 5 | *.egg-info 6 | build/ 7 | dist/ 8 | build.sh 9 | *.job -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ### **[2024.03.02]** v0.1.4 2 | - License made more permissive! The license for the code has been changed to Apache 2.0. The license for the model remains as CC BY-NC 4.0, though (nothing we can do about that). 3 | 4 | ### **[2024.03.01]** v0.1.3 5 | - Added support to save and load models to the huggingface hub, if huggingface_hub is installed. 6 | - Most Hiera models have been uploaded to HuggingFace. 7 | 8 | ### **[2023.07.20]** v0.1.2 9 | - Released the full model zoo. 10 | - Added MAE functionality to the video models. 11 | 12 | ### **[2023.06.12]** v0.1.1 13 | - Added the ability to specify multiple pretrained checkpoints per architecture (specify with `checkpoint=`). 14 | - Added the ability to pass `strict=False` to a pretrained model so that you can use a different number of classes. **Note:** when changing the number of classes, the head layer will be reset. 15 | - Released all in1k finetuned models. 16 | 17 | ### **[2023.06.01]** v0.1.0 18 | - Initial Release. 
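For reference, the `checkpoint=` and `strict=False` options described in the v0.1.1 entry above can be used roughly as follows (a minimal sketch: the `num_classes` keyword is an assumption for illustration, and the available model/checkpoint names are listed in the README):

```py
import hiera

# Select a specific pretrained checkpoint for an architecture.
model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")

# Load the same checkpoint with a different number of classes.
# strict=False skips the mismatched head weights, which are then re-initialized.
model = hiera.hiera_base_224(
    pretrained=True,
    checkpoint="mae_in1k_ft_in1k",
    strict=False,
    num_classes=10,  # assumed keyword for the new head size
)
```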
19 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 
67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to hiera 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to hiera, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /LICENSE.models: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright and 20 | certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | - Considerations for licensors: Our public licenses are intended for 25 | use by those authorized to give the public permission to use 26 | material in ways otherwise restricted by copyright and certain other 27 | rights. Our licenses are irrevocable. Licensors should read and 28 | understand the terms and conditions of the license they choose 29 | before applying it. Licensors should also secure all rights 30 | necessary before applying our licenses so that the public can reuse 31 | the material as expected. Licensors should clearly mark any material 32 | not subject to the license. This includes other CC-licensed 33 | material, or material used under an exception or limitation to 34 | copyright. More considerations for licensors : 35 | wiki.creativecommons.org/Considerations\_for\_licensors 36 | 37 | - Considerations for the public: By using one of our public licenses, 38 | a licensor grants the public permission to use the licensed material 39 | under specified terms and conditions. If the licensor's permission 40 | is not necessary for any reason–for example, because of any 41 | applicable exception or limitation to copyright–then that use is not 42 | regulated by the license. Our licenses grant only permissions under 43 | copyright and certain other rights that a licensor has authority to 44 | grant. Use of the licensed material may still be restricted for 45 | other reasons, including because others have copyright or other 46 | rights in the material. A licensor may make special requests, such 47 | as asking that all changes be marked or described. Although not 48 | required by our licenses, you are encouraged to respect those 49 | requests where reasonable. More considerations for the public : 50 | wiki.creativecommons.org/Considerations\_for\_licensees 51 | 52 | Creative Commons Attribution-NonCommercial 4.0 International Public 53 | License 54 | 55 | By exercising the Licensed Rights (defined below), You accept and agree 56 | to be bound by the terms and conditions of this Creative Commons 57 | Attribution-NonCommercial 4.0 International Public License ("Public 58 | License"). 
To the extent this Public License may be interpreted as a 59 | contract, You are granted the Licensed Rights in consideration of Your 60 | acceptance of these terms and conditions, and the Licensor grants You 61 | such rights in consideration of benefits the Licensor receives from 62 | making the Licensed Material available under these terms and conditions. 63 | 64 | - Section 1 – Definitions. 65 | 66 | - a. Adapted Material means material subject to Copyright and 67 | Similar Rights that is derived from or based upon the Licensed 68 | Material and in which the Licensed Material is translated, 69 | altered, arranged, transformed, or otherwise modified in a 70 | manner requiring permission under the Copyright and Similar 71 | Rights held by the Licensor. For purposes of this Public 72 | License, where the Licensed Material is a musical work, 73 | performance, or sound recording, Adapted Material is always 74 | produced where the Licensed Material is synched in timed 75 | relation with a moving image. 76 | - b. Adapter's License means the license You apply to Your 77 | Copyright and Similar Rights in Your contributions to Adapted 78 | Material in accordance with the terms and conditions of this 79 | Public License. 80 | - c. Copyright and Similar Rights means copyright and/or similar 81 | rights closely related to copyright including, without 82 | limitation, performance, broadcast, sound recording, and Sui 83 | Generis Database Rights, without regard to how the rights are 84 | labeled or categorized. For purposes of this Public License, the 85 | rights specified in Section 2(b)(1)-(2) are not Copyright and 86 | Similar Rights. 87 | - d. Effective Technological Measures means those measures that, 88 | in the absence of proper authority, may not be circumvented 89 | under laws fulfilling obligations under Article 11 of the WIPO 90 | Copyright Treaty adopted on December 20, 1996, and/or similar 91 | international agreements. 92 | - e. Exceptions and Limitations means fair use, fair dealing, 93 | and/or any other exception or limitation to Copyright and 94 | Similar Rights that applies to Your use of the Licensed 95 | Material. 96 | - f. Licensed Material means the artistic or literary work, 97 | database, or other material to which the Licensor applied this 98 | Public License. 99 | - g. Licensed Rights means the rights granted to You subject to 100 | the terms and conditions of this Public License, which are 101 | limited to all Copyright and Similar Rights that apply to Your 102 | use of the Licensed Material and that the Licensor has authority 103 | to license. 104 | - h. Licensor means the individual(s) or entity(ies) granting 105 | rights under this Public License. 106 | - i. NonCommercial means not primarily intended for or directed 107 | towards commercial advantage or monetary compensation. For 108 | purposes of this Public License, the exchange of the Licensed 109 | Material for other material subject to Copyright and Similar 110 | Rights by digital file-sharing or similar means is NonCommercial 111 | provided there is no payment of monetary compensation in 112 | connection with the exchange. 113 | - j. 
Share means to provide material to the public by any means or 114 | process that requires permission under the Licensed Rights, such 115 | as reproduction, public display, public performance, 116 | distribution, dissemination, communication, or importation, and 117 | to make material available to the public including in ways that 118 | members of the public may access the material from a place and 119 | at a time individually chosen by them. 120 | - k. Sui Generis Database Rights means rights other than copyright 121 | resulting from Directive 96/9/EC of the European Parliament and 122 | of the Council of 11 March 1996 on the legal protection of 123 | databases, as amended and/or succeeded, as well as other 124 | essentially equivalent rights anywhere in the world. 125 | - l. You means the individual or entity exercising the Licensed 126 | Rights under this Public License. Your has a corresponding 127 | meaning. 128 | 129 | - Section 2 – Scope. 130 | 131 | - a. License grant. 132 | - 1. Subject to the terms and conditions of this Public 133 | License, the Licensor hereby grants You a worldwide, 134 | royalty-free, non-sublicensable, non-exclusive, irrevocable 135 | license to exercise the Licensed Rights in the Licensed 136 | Material to: 137 | - A. reproduce and Share the Licensed Material, in whole 138 | or in part, for NonCommercial purposes only; and 139 | - B. produce, reproduce, and Share Adapted Material for 140 | NonCommercial purposes only. 141 | - 2. Exceptions and Limitations. For the avoidance of doubt, 142 | where Exceptions and Limitations apply to Your use, this 143 | Public License does not apply, and You do not need to comply 144 | with its terms and conditions. 145 | - 3. Term. The term of this Public License is specified in 146 | Section 6(a). 147 | - 4. Media and formats; technical modifications allowed. The 148 | Licensor authorizes You to exercise the Licensed Rights in 149 | all media and formats whether now known or hereafter 150 | created, and to make technical modifications necessary to do 151 | so. The Licensor waives and/or agrees not to assert any 152 | right or authority to forbid You from making technical 153 | modifications necessary to exercise the Licensed Rights, 154 | including technical modifications necessary to circumvent 155 | Effective Technological Measures. For purposes of this 156 | Public License, simply making modifications authorized by 157 | this Section 2(a)(4) never produces Adapted Material. 158 | - 5. Downstream recipients. 159 | - A. Offer from the Licensor – Licensed Material. Every 160 | recipient of the Licensed Material automatically 161 | receives an offer from the Licensor to exercise the 162 | Licensed Rights under the terms and conditions of this 163 | Public License. 164 | - B. No downstream restrictions. You may not offer or 165 | impose any additional or different terms or conditions 166 | on, or apply any Effective Technological Measures to, 167 | the Licensed Material if doing so restricts exercise of 168 | the Licensed Rights by any recipient of the Licensed 169 | Material. 170 | - 6. No endorsement. Nothing in this Public License 171 | constitutes or may be construed as permission to assert or 172 | imply that You are, or that Your use of the Licensed 173 | Material is, connected with, or sponsored, endorsed, or 174 | granted official status by, the Licensor or others 175 | designated to receive attribution as provided in Section 176 | 3(a)(1)(A)(i). 177 | - b. Other rights. 178 | - 1. 
Moral rights, such as the right of integrity, are not 179 | licensed under this Public License, nor are publicity, 180 | privacy, and/or other similar personality rights; however, 181 | to the extent possible, the Licensor waives and/or agrees 182 | not to assert any such rights held by the Licensor to the 183 | limited extent necessary to allow You to exercise the 184 | Licensed Rights, but not otherwise. 185 | - 2. Patent and trademark rights are not licensed under this 186 | Public License. 187 | - 3. To the extent possible, the Licensor waives any right to 188 | collect royalties from You for the exercise of the Licensed 189 | Rights, whether directly or through a collecting society 190 | under any voluntary or waivable statutory or compulsory 191 | licensing scheme. In all other cases the Licensor expressly 192 | reserves any right to collect such royalties, including when 193 | the Licensed Material is used other than for NonCommercial 194 | purposes. 195 | 196 | - Section 3 – License Conditions. 197 | 198 | Your exercise of the Licensed Rights is expressly made subject to 199 | the following conditions. 200 | 201 | - a. Attribution. 202 | - 1. If You Share the Licensed Material (including in modified 203 | form), You must: 204 | - A. retain the following if it is supplied by the 205 | Licensor with the Licensed Material: 206 | - i. identification of the creator(s) of the Licensed 207 | Material and any others designated to receive 208 | attribution, in any reasonable manner requested by 209 | the Licensor (including by pseudonym if designated); 210 | - ii. a copyright notice; 211 | - iii. a notice that refers to this Public License; 212 | - iv. a notice that refers to the disclaimer of 213 | warranties; 214 | - v. a URI or hyperlink to the Licensed Material to 215 | the extent reasonably practicable; 216 | - B. indicate if You modified the Licensed Material and 217 | retain an indication of any previous modifications; and 218 | - C. indicate the Licensed Material is licensed under this 219 | Public License, and include the text of, or the URI or 220 | hyperlink to, this Public License. 221 | - 2. You may satisfy the conditions in Section 3(a)(1) in any 222 | reasonable manner based on the medium, means, and context in 223 | which You Share the Licensed Material. For example, it may 224 | be reasonable to satisfy the conditions by providing a URI 225 | or hyperlink to a resource that includes the required 226 | information. 227 | - 3. If requested by the Licensor, You must remove any of the 228 | information required by Section 3(a)(1)(A) to the extent 229 | reasonably practicable. 230 | - 4. If You Share Adapted Material You produce, the Adapter's 231 | License You apply must not prevent recipients of the Adapted 232 | Material from complying with this Public License. 233 | 234 | - Section 4 – Sui Generis Database Rights. 235 | 236 | Where the Licensed Rights include Sui Generis Database Rights that 237 | apply to Your use of the Licensed Material: 238 | 239 | - a. for the avoidance of doubt, Section 2(a)(1) grants You the 240 | right to extract, reuse, reproduce, and Share all or a 241 | substantial portion of the contents of the database for 242 | NonCommercial purposes only; 243 | - b. if You include all or a substantial portion of the database 244 | contents in a database in which You have Sui Generis Database 245 | Rights, then the database in which You have Sui Generis Database 246 | Rights (but not its individual contents) is Adapted Material; 247 | and 248 | - c. 
You must comply with the conditions in Section 3(a) if You 249 | Share all or a substantial portion of the contents of the 250 | database. 251 | 252 | For the avoidance of doubt, this Section 4 supplements and does not 253 | replace Your obligations under this Public License where the 254 | Licensed Rights include other Copyright and Similar Rights. 255 | 256 | - Section 5 – Disclaimer of Warranties and Limitation of Liability. 257 | 258 | - a. Unless otherwise separately undertaken by the Licensor, to 259 | the extent possible, the Licensor offers the Licensed Material 260 | as-is and as-available, and makes no representations or 261 | warranties of any kind concerning the Licensed Material, whether 262 | express, implied, statutory, or other. This includes, without 263 | limitation, warranties of title, merchantability, fitness for a 264 | particular purpose, non-infringement, absence of latent or other 265 | defects, accuracy, or the presence or absence of errors, whether 266 | or not known or discoverable. Where disclaimers of warranties 267 | are not allowed in full or in part, this disclaimer may not 268 | apply to You. 269 | - b. To the extent possible, in no event will the Licensor be 270 | liable to You on any legal theory (including, without 271 | limitation, negligence) or otherwise for any direct, special, 272 | indirect, incidental, consequential, punitive, exemplary, or 273 | other losses, costs, expenses, or damages arising out of this 274 | Public License or use of the Licensed Material, even if the 275 | Licensor has been advised of the possibility of such losses, 276 | costs, expenses, or damages. Where a limitation of liability is 277 | not allowed in full or in part, this limitation may not apply to 278 | You. 279 | - c. The disclaimer of warranties and limitation of liability 280 | provided above shall be interpreted in a manner that, to the 281 | extent possible, most closely approximates an absolute 282 | disclaimer and waiver of all liability. 283 | 284 | - Section 6 – Term and Termination. 285 | 286 | - a. This Public License applies for the term of the Copyright and 287 | Similar Rights licensed here. However, if You fail to comply 288 | with this Public License, then Your rights under this Public 289 | License terminate automatically. 290 | - b. Where Your right to use the Licensed Material has terminated 291 | under Section 6(a), it reinstates: 292 | 293 | - 1. automatically as of the date the violation is cured, 294 | provided it is cured within 30 days of Your discovery of the 295 | violation; or 296 | - 2. upon express reinstatement by the Licensor. 297 | 298 | For the avoidance of doubt, this Section 6(b) does not affect 299 | any right the Licensor may have to seek remedies for Your 300 | violations of this Public License. 301 | 302 | - c. For the avoidance of doubt, the Licensor may also offer the 303 | Licensed Material under separate terms or conditions or stop 304 | distributing the Licensed Material at any time; however, doing 305 | so will not terminate this Public License. 306 | - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 307 | License. 308 | 309 | - Section 7 – Other Terms and Conditions. 310 | 311 | - a. The Licensor shall not be bound by any additional or 312 | different terms or conditions communicated by You unless 313 | expressly agreed. 314 | - b. 
Any arrangements, understandings, or agreements regarding the 315 | Licensed Material not stated herein are separate from and 316 | independent of the terms and conditions of this Public License. 317 | 318 | - Section 8 – Interpretation. 319 | 320 | - a. For the avoidance of doubt, this Public License does not, and 321 | shall not be interpreted to, reduce, limit, restrict, or impose 322 | conditions on any use of the Licensed Material that could 323 | lawfully be made without permission under this Public License. 324 | - b. To the extent possible, if any provision of this Public 325 | License is deemed unenforceable, it shall be automatically 326 | reformed to the minimum extent necessary to make it enforceable. 327 | If the provision cannot be reformed, it shall be severed from 328 | this Public License without affecting the enforceability of the 329 | remaining terms and conditions. 330 | - c. No term or condition of this Public License will be waived 331 | and no failure to comply consented to unless expressly agreed to 332 | by the Licensor. 333 | - d. Nothing in this Public License constitutes or may be 334 | interpreted as a limitation upon, or waiver of, any privileges 335 | and immunities that apply to the Licensor or You, including from 336 | the legal processes of any jurisdiction or authority. 337 | 338 | ======================================================================= 339 | 340 | Creative Commons is not a party to its public licenses. Notwithstanding, 341 | Creative Commons may elect to apply one of its public licenses to 342 | material it publishes and in those instances will be considered the 343 | "Licensor." The text of the Creative Commons public licenses is 344 | dedicated to the public domain under the CC0 Public Domain Dedication. 345 | Except for the limited purpose of indicating that material is shared 346 | under a Creative Commons public license or as otherwise permitted by the 347 | Creative Commons policies published at creativecommons.org/policies, 348 | Creative Commons does not authorize the use of the trademark "Creative 349 | Commons" or any other trademark or logo of Creative Commons without its 350 | prior written consent including, without limitation, in connection with 351 | any unauthorized modifications to any of its public licenses or any 352 | other arrangements, understandings, or agreements concerning use of 353 | licensed material. For the avoidance of doubt, this paragraph does not 354 | form part of the public licenses. 355 | 356 | Creative Commons may be contacted at creativecommons.org. 
357 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles 2 | 3 | [![Torch Hub Support](https://img.shields.io/badge/torch_hub-gray?logo=pytorch)](#torch-hub) 4 | [![HF Hub Support](https://img.shields.io/badge/%F0%9F%A4%97_huggingface_hub-gray)](#hugging-face-hub) 5 | [![Torch Hub Support](https://img.shields.io/badge/PyPI-gray?logo=pypi&logoColor=lightblue)](https://pypi.org/project/hiera-transformer/) 6 | [![Python 3.6](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/) 7 | [![Github Release](https://img.shields.io/github/release/facebookresearch/hiera.svg)](https://github.com/facebookresearch/hiera/releases) 8 | [![Code License](https://img.shields.io/badge/code_license-Apache_2.0-olive)](https://opensource.org/licenses/Apache-2.0) 9 | [![Model License](https://img.shields.io/badge/model_zoo_license-CC_BY--NC_4.0-lightgrey)](https://creativecommons.org/licenses/by-nc/4.0/deed.en) 10 | 11 | This is the official implementation for our ICML 2023 Oral paper: 12 | **[Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles][arxiv-link]** 13 | [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ)\*, 14 | [Yuan-Ting Hu](https://scholar.google.com/citations?user=aMpbemkAAAAJ)\*, 15 | [Daniel Bolya](https://scholar.google.com/citations?hl=en&user=K3ht_ZUAAAAJ)\*, 16 | [Chen Wei](https://scholar.google.com/citations?hl=en&user=LHQGpBUAAAAJ), 17 | [Haoqi Fan](https://scholar.google.com/citations?hl=en&user=76B8lrgAAAAJ), 18 | [Po-Yao Huang](https://scholar.google.com/citations?hl=en&user=E8K25LIAAAAJ), 19 | [Vaibhav Aggarwal](https://scholar.google.com/citations?hl=en&user=Qwm6ZOYAAAAJ), 20 | [Arkabandhu Chowdhury](https://scholar.google.com/citations?hl=en&user=42v1i_YAAAAJ), 21 | [Omid Poursaeed](https://scholar.google.com/citations?hl=en&user=Ugw9DX0AAAAJ), 22 | [Judy Hoffman](https://scholar.google.com/citations?hl=en&user=mqpjAt4AAAAJ), 23 | [Jitendra Malik](https://scholar.google.com/citations?hl=en&user=oY9R5YQAAAAJ), 24 | [Yanghao Li](https://scholar.google.com/citations?hl=en&user=-VgS8AIAAAAJ)\*, 25 | [Christoph Feichtenhofer](https://scholar.google.com/citations?hl=en&user=UxuqG1EAAAAJ)\* 26 | _[ICML '23 Oral][icml-link]_ | _[GitHub](https://github.com/facebookresearch/hiera)_ | _[arXiv][arxiv-link]_ | _[BibTeX](https://github.com/facebookresearch/hiera#citation)_ 27 | 28 | \*: Equal contribution. 29 | 30 | ## What is Hiera? 31 | **Hiera** is a _hierarchical_ vision transformer that is fast, powerful, and, above all, _simple_. It outperforms the state-of-the-art across a wide array of image and video tasks _while being much faster_. 32 | 33 |

34 | ![Inference speed comparison.](https://github.com/facebookresearch/hiera/raw/main/examples/img/inference_speed.png)

36 | 37 | ## How does it work? 38 | ![A diagram of Hiera's architecture.](https://github.com/facebookresearch/hiera/raw/main/examples/img/hiera_arch.png) 39 | 40 | Vision transformers like [ViT](https://arxiv.org/abs/2010.11929) use the same spatial resolution and number of features throughout the whole network. But this is inefficient: the early layers don't need that many features, and the later layers don't need that much spatial resolution. Prior hierarchical models like [ResNet](https://arxiv.org/abs/1512.03385) accounted for this by using fewer features at the start and less spatial resolution at the end. 41 | 42 | Several domain specific vision transformers have been introduced that employ this hierarchical design, such as [Swin](https://arxiv.org/abs/2103.14030) or [MViT](https://arxiv.org/abs/2104.11227). But in the pursuit of state-of-the-art results using fully supervised training on ImageNet-1K, these models have become more and more complicated as they add specialized modules to make up for spatial biases that ViTs lack. While these changes produce effective models with attractive FLOP counts, under the hood the added complexity makes these models _slower_ overall. 43 | 44 | We show that a lot of this bulk is actually _unnecessary_. Instead of manually adding spatial bases through architectural changes, we opt to _teach_ the model these biases instead. By training with [MAE](https://arxiv.org/abs/2111.06377), we can simplify or remove _all_ of these bulky modules in existing transformers and _increase accuracy_ in the process. The result is Hiera, an extremely efficient and simple architecture that outperforms the state-of-the-art in several image and video recognition tasks. 45 | 46 | ## News 47 | - **[2024.03.02]** License for the code has been made more permissive (Apache 2.0)! Model license remains unchanged. 48 | - **[2023.06.12]** Added more in1k models and some video examples, see inference.ipynb (v0.1.1). 49 | - **[2023.06.01]** Initial release. 50 | 51 | See the [changelog](https://github.com/facebookresearch/hiera/tree/main/CHANGELOG.md) for more details. 52 | 53 | ## Installation 54 | 55 | Hiera requires a reasonably recent version of [torch](https://pytorch.org/get-started/locally/). 56 | After that, you can install hiera through [pip](https://pypi.org/project/hiera-transformer/): 57 | ```bash 58 | pip install hiera-transformer 59 | ``` 60 | This repo _should_ support the latest timm version, but timm is a constantly updating package. Create an issue if you have problems with a newer version of timm. 61 | 62 | ### Installing from Source 63 | 64 | If using [torch hub](#model-zoo), you don't need to install the `hiera` package. But, if you'd like to develop using hiera, it could be a good idea to install it from source: 65 | 66 | ```bash 67 | git clone https://github.com/facebookresearch/hiera.git 68 | cd hiera 69 | python setup.py build develop 70 | ``` 71 | 72 | 73 | ## Model Zoo 74 | Note that model weights are released under a separate license than the code. See the [model license](LICENSE.models) for more details. 75 | 76 | ### Torch Hub 77 | 78 | Here we provide model checkpoints for Hiera. Each model listed is accessible on [torch hub](https://pytorch.org/docs/stable/hub.html) even without the `hiera-transformer` package installed, e.g. 
the following initializes a base model pretrained and finetuned on ImageNet-1k: 79 | ```py 80 | model = torch.hub.load("facebookresearch/hiera", model="hiera_base_224", pretrained=True, checkpoint="mae_in1k_ft_in1k") 81 | ``` 82 | 83 | If you want a model with MAE pretraining only, you can replace the checkpoint with `"mae_in1k"`. Additionally, if you'd like to load the MAE decoder as well (e.g., to continue pretraining), add `mae_` the the start of the model name, e.g.: 84 | ```py 85 | model = torch.hub.load("facebookresearch/hiera", model="mae_hiera_base_224", pretrained=True, checkpoint="mae_in1k") 86 | ``` 87 | **Note:** Our MAE models were trained with a _normalized pixel loss_. That means that the patches were normalized before the network had to predict them. If you want to visualize the predictions, you'll have to unnormalize them using the visible patches (which might work but wouldn't be perfect) or unnormalize them using the ground truth. For model more names and corresponding checkpoint names see below. 88 | 89 | ### Hugging Face Hub 90 | 91 | This repo also has [🤗 hub](https://huggingface.co/docs/hub/index) support. With the `hiera-transformer` and `huggingface-hub` packages installed, you can simply run, e.g., 92 | ```py 93 | from hiera import Hiera 94 | model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k_ft_in1k") # mae pt then in1k ft'd model 95 | model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k") # just mae pt, no ft 96 | ``` 97 | to load a model. Use `.` from model zoo below. 98 | 99 | If you want to save a model, use `model.config` as the config, e.g., 100 | ```py 101 | model.save_pretrained("hiera-base-224", config=model.config) 102 | ``` 103 | 104 | ### Image Models 105 | | Model | Model Name | Pretrained Models
(IN-1K MAE) | Finetuned Models
(IN-1K Supervised) | IN-1K
Top-1 (%) | A100 fp16
Speed (im/s) | 106 | |----------|-----------------------|----------------------------------|----------------------------------------|:------------------:|:-------------------------:| 107 | | Hiera-T | `hiera_tiny_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth) | 82.8 | 2758 | 108 | | Hiera-S | `hiera_small_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth) | 83.8 | 2211 | 109 | | Hiera-B | `hiera_base_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth) | 84.5 | 1556 | 110 | | Hiera-B+ | `hiera_base_plus_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth) | 85.2 | 1247 | 111 | | Hiera-L | `hiera_large_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth) | 86.1 | 531 | 112 | | Hiera-H | `hiera_huge_224` | [mae_in1k](https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth) | [mae_in1k_ft_in1k](https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth) | 86.9 | 274 | 113 | 114 | Each model inputs a 224x224 image. 115 | ### Video Models 116 | | Model | Model Name | Pretrained Models
(K400 MAE) | Finetuned Models
(K400) | K400 (3x5 views)
Top-1 (%) | A100 fp16
Speed (clip/s) | 117 | |----------|--------------------------|---------------------------------|----------------------------|:-----------------------------:|:---------------------------:| 118 | | Hiera-B | `hiera_base_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth) | 84.0 | 133.6 | 119 | | Hiera-B+ | `hiera_base_plus_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth) | 85.0 | 84.1 | 120 | | Hiera-L | `hiera_large_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth) | 87.3 | 40.8 | 121 | | Hiera-H | `hiera_huge_16x224` | [mae_k400](https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth) | [mae_k400_ft_k400](https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth) | 87.8 | 20.9 | 122 | 123 | Each model inputs 16 224x224 frames with a temporal stride of 4. 124 | 125 | **Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here. 126 | 127 | ## Usage 128 | 129 | This repo implements the code to run Hiera models for inference. This repository is still in progress. Here's what we currently have available and what we have planned: 130 | 131 | - [x] Image Inference 132 | - [x] MAE implementation 133 | - [x] Video Inference 134 | - [x] MAE implementation 135 | - [x] Full Model Zoo 136 | - [ ] Training scripts 137 | 138 | 139 | See [examples](https://github.com/facebookresearch/hiera/tree/main/examples) for examples of how to use Hiera. 140 | 141 | ### Inference 142 | 143 | See [examples/inference](https://github.com/facebookresearch/hiera/blob/main/examples/inference.ipynb) for an example of how to prepare the data for inference. 144 | 145 | Instantiate a model with either [torch hub](#model-zoo) or [🤗 hub](#model-zoo) or by [installing hiera](#installing-from-source) and running: 146 | ```py 147 | import hiera 148 | model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k") 149 | ``` 150 | Then you can run inference like any other model: 151 | ```py 152 | output = model(x) 153 | ``` 154 | Video inference works the same way, just use a `16x224` model instead. 155 | 156 | **Note**: for efficiency, Hiera re-orders its tokens at the start of the network (see the `Roll` and `Unroll` modules in `hiera_utils.py`). Thus, tokens _aren't in spatial order_ by default. If you'd like to use intermediate feature maps for a downstream task, pass the `return_intermediates` flag when running the model: 157 | ```py 158 | output, intermediates = model(x, return_intermediates=True) 159 | ``` 160 | 161 | #### MAE Inference 162 | By default, the models do not include the MAE decoder. If you would like to use the decoder or compute MAE loss, you can instantiate an mae version by running: 163 | ```py 164 | import hiera 165 | model = hiera.mae_hiera_base_224(pretrained=True, checkpoint="mae_in1k") 166 | ``` 167 | Then when you run inference on the model, it will return a 4-tuple of `(loss, predictions, labels, mask)` where predictions and labels are for the _deleted tokens_ only. 
The returned mask will be `True` if the token is visible and `False` if it's deleted. You can change the masking ratio by passing it during inference: 168 | ```py 169 | loss, preds, labels, mask = model(x, mask_ratio=0.6) 170 | ``` 171 | The default mask ratio is `0.6` for images, but you should pass in `0.9` for video. See the paper for details. 172 | 173 | **Note:** We use _normalized pixel targets_ for MAE pretraining, meaning the patches are each individually normalized before the model model has to predict them. Thus, you have to unnormalize them using the ground truth before visualizing them. See `get_pixel_label_2d` in `hiera_mae.py` for details. 174 | 175 | ### Benchmarking 176 | We provide a script for easy benchmarking. See [examples/benchmark](https://github.com/facebookresearch/hiera/blob/main/examples/benchmark.ipynb) to see how to use it. 177 | 178 | #### Scaled Dot Product Attention 179 | PyTorch 2.0 introduced optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), which can speed up transformers quite a bit. We didn't use this in our original benchmarking, but since it's a free speed-up this repo will automatically use it if available. To get its benefits, make sure your torch version is 2.0 or above. 180 | 181 | ### Training 182 | 183 | Coming soon. 184 | 185 | 186 | ## Citation 187 | If you use Hiera or this code in your work, please cite: 188 | ``` 189 | @article{ryali2023hiera, 190 | title={Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles}, 191 | author={Ryali, Chaitanya and Hu, Yuan-Ting and Bolya, Daniel and Wei, Chen and Fan, Haoqi and Huang, Po-Yao and Aggarwal, Vaibhav and Chowdhury, Arkabandhu and Poursaeed, Omid and Hoffman, Judy and Malik, Jitendra and Li, Yanghao and Feichtenhofer, Christoph}, 192 | journal={ICML}, 193 | year={2023} 194 | } 195 | ``` 196 | 197 | ### License 198 | The code for this work is licensed under the [Apache License, Version 2.0](https://opensource.org/licenses/Apache-2.0), while the model weights are licensed under the [Creative Commons Attribution-NonCommercial 4.0 International License](https://creativecommons.org/licenses/by-nc/4.0/). 199 | 200 | See [LICENSE](LICENSE) for more details on the code license, and [LICENSE.models](LICENSE.models) for more details on the model weight license. 201 | 202 | ### Contributing 203 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 204 | 205 | [arxiv-link]: https://arxiv.org/abs/2306.00989/ 206 | [icml-link]: https://icml.cc/Conferences/2023 207 | -------------------------------------------------------------------------------- /examples/benchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# **Example**: Hiera Benchmarking\n", 9 | "\n", 10 | "Sample code for how to benchmark Hiera models for different modalities.\n", 11 | "You might have to fiddle with the batch size to get the highest numbers for your environment.\n", 12 | "\n", 13 | "**Note**: Requires the `hiera` package to be installed." 
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import hiera\n", 23 | "from hiera.benchmarking import benchmark" 24 | ] 25 | }, 26 | { 27 | "attachments": {}, 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Images\n", 32 | "Benchmarking a Hiera model on 224x224 images. Results are in im/s.\n", 33 | "\n", 34 | "**Note**: I'm using a Quadro GP100 here, your results should be better." 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stderr", 44 | "output_type": "stream", 45 | "text": [ 46 | "Benchmarking: 100%|██████████| 40/40 [00:10<00:00, 3.94it/s]\n" 47 | ] 48 | }, 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "Throughput: 264.28 im/s\n" 54 | ] 55 | }, 56 | { 57 | "data": { 58 | "text/plain": [ 59 | "264.280846570216" 60 | ] 61 | }, 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "output_type": "execute_result" 65 | } 66 | ], 67 | "source": [ 68 | "# Create a Hiera-B model for images\n", 69 | "model = hiera.hiera_base_224()\n", 70 | "\n", 71 | "# Run an fp16 benchmark\n", 72 | "benchmark(model, device=0, input_size=(3, 224, 224), batch_size=64, runs=40, use_fp16=True, verbose=True)" 73 | ] 74 | }, 75 | { 76 | "attachments": {}, 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "## Video\n", 81 | "Benchmarking a Hiera model on 16 frames of 224x224 images. Results are in clips/s.\n", 82 | "\n", 83 | "**Note**: I'm using a Quadro GP100 here, your results should be better." 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 3, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stderr", 93 | "output_type": "stream", 94 | "text": [ 95 | "Benchmarking: 100%|██████████| 40/40 [00:12<00:00, 3.17it/s]\n" 96 | ] 97 | }, 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Throughput: 24.77 im/s\n" 103 | ] 104 | }, 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "24.76710762205785" 109 | ] 110 | }, 111 | "execution_count": 3, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "# Create a Hiera-B model for video\n", 118 | "model = hiera.hiera_base_16x224()\n", 119 | "\n", 120 | "# Run an fp16 benchmark\n", 121 | "benchmark(model, device=0, input_size=(3, 16, 224, 224), batch_size=8, runs=40, use_fp16=True, verbose=True)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.10.11" 149 | }, 150 | "orig_nbformat": 4 151 | }, 152 | "nbformat": 4, 153 | "nbformat_minor": 2 154 | } 155 | -------------------------------------------------------------------------------- /examples/img/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/hiera/b12b842542ee5c757fcfec8c41f6b56fcbe89b65/examples/img/dog.jpg 
-------------------------------------------------------------------------------- /examples/img/hiera_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/hiera/b12b842542ee5c757fcfec8c41f6b56fcbe89b65/examples/img/hiera_arch.png -------------------------------------------------------------------------------- /examples/img/inference_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/hiera/b12b842542ee5c757fcfec8c41f6b56fcbe89b65/examples/img/inference_speed.png -------------------------------------------------------------------------------- /examples/vid/dog.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/hiera/b12b842542ee5c757fcfec8c41f6b56fcbe89b65/examples/vid/dog.mp4 -------------------------------------------------------------------------------- /examples/vid/goat.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/hiera/b12b842542ee5c757fcfec8c41f6b56fcbe89b65/examples/vid/goat.mp4 -------------------------------------------------------------------------------- /hiera/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | 8 | from .hiera import ( 9 | hiera_tiny_224, 10 | hiera_small_224, 11 | hiera_base_224, 12 | hiera_base_plus_224, 13 | hiera_large_224, 14 | hiera_huge_224, 15 | 16 | hiera_base_16x224, 17 | hiera_base_plus_16x224, 18 | hiera_large_16x224, 19 | hiera_huge_16x224, 20 | 21 | Hiera, 22 | HieraBlock, 23 | MaskUnitAttention, 24 | Head, 25 | PatchEmbed, 26 | ) 27 | 28 | 29 | from .hiera_mae import ( 30 | mae_hiera_tiny_224, 31 | mae_hiera_small_224, 32 | mae_hiera_base_224, 33 | mae_hiera_base_plus_224, 34 | mae_hiera_large_224, 35 | mae_hiera_huge_224, 36 | 37 | mae_hiera_base_16x224, 38 | mae_hiera_base_plus_16x224, 39 | mae_hiera_large_16x224, 40 | mae_hiera_huge_16x224, 41 | 42 | MaskedAutoencoderHiera, 43 | ) -------------------------------------------------------------------------------- /hiera/benchmarking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | from typing import List, Tuple, Union 10 | 11 | import torch 12 | from tqdm import tqdm 13 | 14 | # From https://github.com/facebookresearch/ToMe/ 15 | def benchmark( 16 | model: torch.nn.Module, 17 | device: torch.device = 0, 18 | input_size: Tuple[int] = (3, 224, 224), 19 | batch_size: int = 64, 20 | runs: int = 40, 21 | throw_out: float = 0.25, 22 | use_fp16: bool = False, 23 | verbose: bool = False, 24 | ) -> float: 25 | """ 26 | Benchmark the given model with random inputs at the given batch size. 
27 | 28 | Args: 29 | - model: the module to benchmark 30 | - device: the device to use for benchmarking 31 | - input_size: the input size to pass to the model e.g., (ch, h, w) or (ch, t, h, w) 32 | - batch_size: the batch size to use for evaluation 33 | - runs: the number of total runs to do 34 | - throw_out: the percentage of runs to throw out at the start of testing 35 | - use_fp16: whether or not to benchmark with float16 and autocast 36 | - verbose: whether or not to use tqdm to print progress / print throughput at end 37 | 38 | Returns: 39 | - the throughput measured in images / second 40 | """ 41 | if not isinstance(device, torch.device): 42 | device = torch.device(device) 43 | is_cuda = torch.device(device).type == "cuda" 44 | 45 | model = model.eval().to(device) 46 | input = torch.rand(batch_size, *input_size, device=device) 47 | if use_fp16: 48 | input = input.half() 49 | 50 | warm_up = int(runs * throw_out) 51 | total = 0 52 | start = time.time() 53 | 54 | with torch.autocast(device.type, enabled=use_fp16): 55 | with torch.no_grad(): 56 | for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"): 57 | if i == warm_up: 58 | if is_cuda: 59 | torch.cuda.synchronize() 60 | total = 0 61 | start = time.time() 62 | 63 | model(input) 64 | total += batch_size 65 | 66 | if is_cuda: 67 | torch.cuda.synchronize() 68 | 69 | end = time.time() 70 | elapsed = end - start 71 | 72 | throughput = total / elapsed 73 | 74 | if verbose: 75 | print(f"Throughput: {throughput:.2f} im/s") 76 | 77 | return throughput 78 | -------------------------------------------------------------------------------- /hiera/hfhub.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # https://github.com/facebookresearch/hiera/pull/25 9 | # -------------------------------------------------------- 10 | 11 | import importlib.util 12 | import importlib.metadata 13 | from packaging import version 14 | 15 | import inspect 16 | 17 | def is_huggingface_hub_available(): 18 | available: bool = importlib.util.find_spec("huggingface_hub") is not None 19 | 20 | if not available: 21 | return False 22 | else: 23 | hfversion = importlib.metadata.version("huggingface_hub") 24 | return version.parse(hfversion) >= version.parse("0.21.0") 25 | 26 | 27 | if is_huggingface_hub_available(): 28 | from huggingface_hub import PyTorchModelHubMixin 29 | else: 30 | # Empty class in case modelmixins dont exist 31 | class PyTorchModelHubMixin: 32 | error_str: str = 'This feature requires "huggingface-hub >= 0.21.0" to be installed.' 33 | 34 | @classmethod 35 | def from_pretrained(cls, *args, **kwdargs): 36 | raise RuntimeError(cls.error_str) 37 | 38 | @classmethod 39 | def save_pretrained(cls, *args, **kwdargs): 40 | raise RuntimeError(cls.error_str) 41 | 42 | @classmethod 43 | def push_to_hub(cls, *args, **kwdargs): 44 | raise RuntimeError(cls.error_str) 45 | 46 | 47 | 48 | # Saves the input args to the function as self.config, also allows 49 | # loading a config instead of kwdargs. 
50 | def has_config(func): 51 | signature = inspect.signature(func) 52 | 53 | def wrapper(self, *args, **kwdargs): 54 | if "config" in kwdargs: 55 | config = kwdargs["config"] 56 | del kwdargs["config"] 57 | kwdargs.update(**config) 58 | 59 | self.config = { 60 | k: v.default if (i-1) >= len(args) else args[i-1] 61 | for i, (k, v) in enumerate(signature.parameters.items()) 62 | if v.default is not inspect.Parameter.empty 63 | } 64 | self.config.update(**kwdargs) 65 | 66 | func(self, **kwdargs) 67 | return wrapper 68 | -------------------------------------------------------------------------------- /hiera/hiera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # 8 | # Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles 9 | # 10 | # Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan, 11 | # Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed, 12 | # Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer. 13 | # 14 | # Paper: https://arxiv.org/abs/2306.00989/ 15 | # 16 | # References: 17 | # slowfast: https://github.com/facebookresearch/SlowFast 18 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 19 | # -------------------------------------------------------- 20 | 21 | import math 22 | from functools import partial 23 | from typing import List, Tuple, Callable, Optional, Union 24 | 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | from timm.models.layers import DropPath, Mlp 30 | 31 | from .hiera_utils import pretrained_model, conv_nd, do_pool, do_masked_conv, Unroll, Reroll 32 | from .hfhub import has_config, PyTorchModelHubMixin 33 | 34 | 35 | class MaskUnitAttention(nn.Module): 36 | """ 37 | Computes either Mask Unit or Global Attention. Also is able to perform q pooling. 38 | 39 | Note: this assumes the tokens have already been flattened and unrolled into mask units. 40 | See `Unroll` for more details. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | dim: int, 46 | dim_out: int, 47 | heads: int, 48 | q_stride: int = 1, 49 | window_size: int = 0, 50 | use_mask_unit_attn: bool = False, 51 | ): 52 | """ 53 | Args: 54 | - dim, dim_out: The input and output feature dimensions. 55 | - heads: The number of attention heads. 56 | - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4). 57 | - window_size: The current (flattened) size of a mask unit *after* pooling (if any). 58 | - use_mask_unit_attn: Use Mask Unit or Global Attention. 59 | """ 60 | super().__init__() 61 | 62 | self.dim = dim 63 | self.dim_out = dim_out 64 | self.heads = heads 65 | self.q_stride = q_stride 66 | 67 | self.head_dim = dim_out // heads 68 | self.scale = (self.head_dim) ** -0.5 69 | 70 | self.qkv = nn.Linear(dim, 3 * dim_out) 71 | self.proj = nn.Linear(dim_out, dim_out) 72 | 73 | self.window_size = window_size 74 | self.use_mask_unit_attn = use_mask_unit_attn 75 | 76 | def forward(self, x: torch.Tensor) -> torch.Tensor: 77 | """ Input should be of shape [batch, tokens, channels]. 
""" 78 | B, N, _ = x.shape 79 | num_windows = ( 80 | (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1 81 | ) 82 | 83 | qkv = ( 84 | self.qkv(x) 85 | .reshape(B, -1, num_windows, 3, self.heads, self.head_dim) 86 | .permute(3, 0, 4, 2, 1, 5) 87 | ) 88 | q, k, v = qkv[0], qkv[1], qkv[2] 89 | 90 | if self.q_stride > 1: 91 | # Refer to Unroll to see how this performs a maxpool-Nd 92 | q = ( 93 | q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim) 94 | .max(dim=3) 95 | .values 96 | ) 97 | 98 | if hasattr(F, "scaled_dot_product_attention"): 99 | # Note: the original paper did *not* use SDPA, it's a free boost! 100 | x = F.scaled_dot_product_attention(q, k, v) 101 | else: 102 | attn = (q * self.scale) @ k.transpose(-1, -2) 103 | attn = attn.softmax(dim=-1) 104 | x = (attn @ v) 105 | 106 | x = x.transpose(1, 3).reshape(B, -1, self.dim_out) 107 | x = self.proj(x) 108 | return x 109 | 110 | 111 | class HieraBlock(nn.Module): 112 | def __init__( 113 | self, 114 | dim: int, 115 | dim_out: int, 116 | heads: int, 117 | mlp_ratio: float = 4.0, 118 | drop_path: float = 0.0, 119 | norm_layer: nn.Module = nn.LayerNorm, 120 | act_layer: nn.Module = nn.GELU, 121 | q_stride: int = 1, 122 | window_size: int = 0, 123 | use_mask_unit_attn: bool = False, 124 | ): 125 | super().__init__() 126 | 127 | self.dim = dim 128 | self.dim_out = dim_out 129 | 130 | self.norm1 = norm_layer(dim) 131 | self.attn = MaskUnitAttention( 132 | dim, dim_out, heads, q_stride, window_size, use_mask_unit_attn 133 | ) 134 | 135 | self.norm2 = norm_layer(dim_out) 136 | self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) 137 | 138 | self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() 139 | if dim != dim_out: 140 | self.proj = nn.Linear(dim, dim_out) 141 | 142 | def forward(self, x: torch.Tensor) -> torch.Tensor: 143 | # Attention + Q Pooling 144 | x_norm = self.norm1(x) 145 | if self.dim != self.dim_out: 146 | x = do_pool(self.proj(x_norm), stride=self.attn.q_stride) 147 | x = x + self.drop_path(self.attn(x_norm)) 148 | 149 | # MLP 150 | x = x + self.drop_path(self.mlp(self.norm2(x))) 151 | return x 152 | 153 | 154 | class Head(nn.Module): 155 | def __init__( 156 | self, 157 | dim: int, 158 | num_classes: int, 159 | dropout_rate: float = 0.0, 160 | act_func: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.softmax(dim=-1), 161 | ): 162 | super().__init__() 163 | self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity() 164 | self.projection = nn.Linear(dim, num_classes) 165 | # act_fun for eval and testing only 166 | self.act_func = act_func 167 | 168 | def forward(self, x: torch.Tensor) -> torch.Tensor: 169 | x = self.dropout(x) 170 | x = self.projection(x) 171 | if not self.training: 172 | x = self.act_func(x) 173 | return x 174 | 175 | 176 | class PatchEmbed(nn.Module): 177 | """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d).""" 178 | 179 | def __init__( 180 | self, 181 | dim_in: int, 182 | dim_out: int, 183 | kernel: Tuple[int, ...], 184 | stride: Tuple[int, ...], 185 | padding: Tuple[int, ...], 186 | ): 187 | super().__init__() 188 | 189 | # Support any number of spatial dimensions 190 | self.spatial_dims = len(kernel) 191 | self.proj = conv_nd(self.spatial_dims)( 192 | dim_in, 193 | dim_out, 194 | kernel_size=kernel, 195 | stride=stride, 196 | padding=padding, 197 | ) 198 | 199 | def forward( 200 | self, x: torch.Tensor, mask: Optional[torch.Tensor] = None 201 | ) -> torch.Tensor: 202 | x = 
do_masked_conv(x, self.proj, mask) 203 | x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1) 204 | return x 205 | 206 | 207 | class Hiera(nn.Module, PyTorchModelHubMixin): 208 | @has_config 209 | def __init__( 210 | self, 211 | input_size: Tuple[int, ...] = (224, 224), 212 | in_chans: int = 3, 213 | embed_dim: int = 96, # initial embed dim 214 | num_heads: int = 1, # initial number of heads 215 | num_classes: int = 1000, 216 | stages: Tuple[int, ...] = (2, 3, 16, 3), 217 | q_pool: int = 3, # number of q_pool stages 218 | q_stride: Tuple[int, ...] = (2, 2), 219 | mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1) 220 | # mask_unit_attn: which stages use mask unit attention? 221 | mask_unit_attn: Tuple[bool, ...] = (True, True, False, False), 222 | dim_mul: float = 2.0, 223 | head_mul: float = 2.0, 224 | patch_kernel: Tuple[int, ...] = (7, 7), 225 | patch_stride: Tuple[int, ...] = (4, 4), 226 | patch_padding: Tuple[int, ...] = (3, 3), 227 | mlp_ratio: float = 4.0, 228 | drop_path_rate: float = 0.0, 229 | norm_layer: Union[str, nn.Module] = "LayerNorm", 230 | head_dropout: float = 0.0, 231 | head_init_scale: float = 0.001, 232 | sep_pos_embed: bool = False, 233 | ): 234 | super().__init__() 235 | 236 | # Do it this way to ensure that the init args are all PoD (for config usage) 237 | if isinstance(norm_layer, str): 238 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 239 | 240 | depth = sum(stages) 241 | self.patch_stride = patch_stride 242 | self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)] 243 | num_tokens = math.prod(self.tokens_spatial_shape) 244 | flat_mu_size = math.prod(mask_unit_size) 245 | flat_q_stride = math.prod(q_stride) 246 | 247 | assert q_pool < len(stages) 248 | self.q_pool, self.q_stride = q_pool, q_stride 249 | self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size 250 | self.mask_spatial_shape = [ 251 | i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size) 252 | ] 253 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 254 | 255 | self.patch_embed = PatchEmbed( 256 | in_chans, embed_dim, patch_kernel, patch_stride, patch_padding 257 | ) 258 | 259 | self.sep_pos_embed = sep_pos_embed 260 | if sep_pos_embed: 261 | self.pos_embed_spatial = nn.Parameter( 262 | torch.zeros( 263 | 1, 264 | self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], 265 | embed_dim, 266 | ) 267 | ) 268 | self.pos_embed_temporal = nn.Parameter( 269 | torch.zeros(1, self.tokens_spatial_shape[0], embed_dim) 270 | ) 271 | else: 272 | self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim)) 273 | 274 | # Setup roll and reroll modules 275 | self.unroll = Unroll( 276 | input_size, patch_stride, [q_stride] * len(self.stage_ends[:-1]) 277 | ) 278 | self.reroll = Reroll( 279 | input_size, 280 | patch_stride, 281 | [q_stride] * len(self.stage_ends[:-1]), 282 | self.stage_ends, 283 | q_pool, 284 | ) 285 | # q_pool locations 286 | q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]] 287 | # stochastic depth decay rule 288 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 289 | 290 | # Transformer blocks 291 | cur_stage = 0 292 | self.blocks = nn.ModuleList() 293 | 294 | for i in range(depth): 295 | dim_out = embed_dim 296 | # Mask unit or global attention. 
297 | # Lag by 1 block, so that global attention, 298 | # applied post pooling on lower resolution 299 | use_mask_unit_attn = mask_unit_attn[cur_stage] 300 | 301 | if i - 1 in self.stage_ends: 302 | dim_out = int(embed_dim * dim_mul) 303 | num_heads = int(num_heads * head_mul) 304 | cur_stage += 1 305 | if i in q_pool_blocks: 306 | flat_mu_size //= flat_q_stride 307 | 308 | block = HieraBlock( 309 | dim=embed_dim, 310 | dim_out=dim_out, 311 | heads=num_heads, 312 | mlp_ratio=mlp_ratio, 313 | drop_path=dpr[i], 314 | norm_layer=norm_layer, 315 | q_stride=(flat_q_stride if i in q_pool_blocks else 1), 316 | window_size=flat_mu_size, 317 | use_mask_unit_attn=use_mask_unit_attn, 318 | ) 319 | 320 | embed_dim = dim_out 321 | self.blocks.append(block) 322 | 323 | self.norm = norm_layer(embed_dim) 324 | self.head = Head(embed_dim, num_classes, dropout_rate=head_dropout) 325 | 326 | # Initialize everything 327 | if sep_pos_embed: 328 | nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02) 329 | nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02) 330 | else: 331 | nn.init.trunc_normal_(self.pos_embed, std=0.02) 332 | self.apply(partial(self._init_weights)) 333 | self.head.projection.weight.data.mul_(head_init_scale) 334 | self.head.projection.bias.data.mul_(head_init_scale) 335 | 336 | def _init_weights(self, m, init_bias=0.02): 337 | if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)): 338 | nn.init.trunc_normal_(m.weight, std=0.02) 339 | if isinstance(m, nn.Linear) and m.bias is not None: 340 | nn.init.constant_(m.bias, init_bias) 341 | elif isinstance(m, nn.LayerNorm): 342 | nn.init.constant_(m.bias, init_bias) 343 | nn.init.constant_(m.weight, 1.0) 344 | 345 | @torch.jit.ignore 346 | def no_weight_decay(self): 347 | if self.sep_pos_embed: 348 | return ["pos_embed_spatial", "pos_embed_temporal"] 349 | else: 350 | return ["pos_embed"] 351 | 352 | def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor: 353 | """ 354 | Generates a random mask, mask_ratio fraction are dropped. 355 | 1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc. 356 | """ 357 | B = x.shape[0] 358 | # Tokens selected for masking at mask unit level 359 | num_windows = math.prod(self.mask_spatial_shape) # num_mask_units 360 | len_keep = int(num_windows * (1 - mask_ratio)) 361 | noise = torch.rand(B, num_windows, device=x.device) 362 | 363 | # Sort noise for each sample 364 | ids_shuffle = torch.argsort( 365 | noise, dim=1 366 | ) # ascend: small is keep, large is remove 367 | ids_restore = torch.argsort(ids_shuffle, dim=1) 368 | 369 | # Generate the binary mask: 1 is *keep*, 0 is *remove* 370 | # Note this is opposite to original MAE 371 | mask = torch.zeros([B, num_windows], device=x.device) 372 | mask[:, :len_keep] = 1 373 | # Unshuffle to get the binary mask 374 | mask = torch.gather(mask, dim=1, index=ids_restore) 375 | 376 | return mask.bool() 377 | 378 | def get_pos_embed(self) -> torch.Tensor: 379 | if self.sep_pos_embed: 380 | return self.pos_embed_spatial.repeat( 381 | 1, self.tokens_spatial_shape[0], 1 382 | ) + torch.repeat_interleave( 383 | self.pos_embed_temporal, 384 | self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], 385 | dim=1, 386 | ) 387 | else: 388 | return self.pos_embed 389 | 390 | def forward( 391 | self, 392 | x: torch.Tensor, 393 | mask: torch.Tensor = None, 394 | return_intermediates: bool = False, 395 | ) -> torch.Tensor: 396 | """ 397 | mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim. 
398 | Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch. 399 | """ 400 | # Slowfast training passes in a list 401 | if isinstance(x, list): 402 | x = x[0] 403 | intermediates = [] 404 | 405 | x = self.patch_embed( 406 | x, 407 | mask=mask.view( 408 | x.shape[0], 1, *self.mask_spatial_shape 409 | ) # B, C, *mask_spatial_shape 410 | if mask is not None 411 | else None, 412 | ) 413 | x = x + self.get_pos_embed() 414 | x = self.unroll(x) 415 | 416 | # Discard masked tokens 417 | if mask is not None: 418 | x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view( 419 | x.shape[0], -1, x.shape[-1] 420 | ) 421 | 422 | for i, blk in enumerate(self.blocks): 423 | x = blk(x) 424 | 425 | if return_intermediates and i in self.stage_ends: 426 | intermediates.append(self.reroll(x, i, mask=mask)) 427 | 428 | if mask is None: 429 | x = x.mean(dim=1) 430 | x = self.norm(x) 431 | x = self.head(x) 432 | 433 | # x may not always be in spatial order here. 434 | # e.g. if q_pool = 2, mask_unit_size = (8, 8), and 435 | # q_stride = (2, 2), not all unrolls were consumed, 436 | # intermediates[-1] is x in spatial order 437 | if return_intermediates: 438 | return x, intermediates 439 | 440 | return x 441 | 442 | 443 | # Image models 444 | 445 | @pretrained_model({ 446 | "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth", 447 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth", 448 | }, default="mae_in1k_ft_in1k") 449 | def hiera_tiny_224(**kwdargs): 450 | return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), **kwdargs) 451 | 452 | 453 | @pretrained_model({ 454 | "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth", 455 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth", 456 | }, default="mae_in1k_ft_in1k") 457 | def hiera_small_224(**kwdargs): 458 | return Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), **kwdargs) 459 | 460 | 461 | @pretrained_model({ 462 | "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth", 463 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth", 464 | }, default="mae_in1k_ft_in1k") 465 | def hiera_base_224(**kwdargs): 466 | return Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), **kwdargs) 467 | 468 | 469 | @pretrained_model({ 470 | "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth", 471 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth", 472 | }, default="mae_in1k_ft_in1k") 473 | def hiera_base_plus_224(**kwdargs): 474 | return Hiera(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs) 475 | 476 | 477 | @pretrained_model({ 478 | "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth", 479 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth", 480 | }, default="mae_in1k_ft_in1k") 481 | def hiera_large_224(**kwdargs): 482 | return Hiera(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs) 483 | 484 | 485 | @pretrained_model({ 486 | "mae_in1k_ft_in1k": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth", 487 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth", 488 | }, default="mae_in1k_ft_in1k") 489 | def hiera_huge_224(**kwdargs): 490 | return Hiera(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs) 491 | 492 | 493 | # Video models 494 | 495 | @pretrained_model({ 496 | "mae_k400_ft_k400": 
"https://dl.fbaipublicfiles.com/hiera/hiera_base_16x224.pth", 497 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth", 498 | }, default="mae_k400_ft_k400") 499 | def hiera_base_16x224(num_classes: int = 400, **kwdargs): 500 | return Hiera( 501 | num_classes=num_classes, # K400 has 400 classes 502 | input_size=(16, 224, 224), 503 | q_stride=(1, 2, 2), 504 | mask_unit_size=(1, 8, 8), 505 | patch_kernel=(3, 7, 7), 506 | patch_stride=(2, 4, 4), 507 | patch_padding=(1, 3, 3), 508 | sep_pos_embed=True, 509 | **kwdargs 510 | ) 511 | 512 | 513 | @pretrained_model({ 514 | "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_16x224.pth", 515 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth", 516 | }, default="mae_k400_ft_k400") 517 | def hiera_base_plus_16x224(**kwdargs): 518 | return hiera_base_16x224( 519 | embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs 520 | ) 521 | 522 | 523 | @pretrained_model({ 524 | "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_large_16x224.pth", 525 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth", 526 | }, default="mae_k400_ft_k400") 527 | def hiera_large_16x224(**kwdargs): 528 | return hiera_base_16x224( 529 | embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs 530 | ) 531 | 532 | 533 | @pretrained_model({ 534 | "mae_k400_ft_k400": "https://dl.fbaipublicfiles.com/hiera/hiera_huge_16x224.pth", 535 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth", 536 | }, default="mae_k400_ft_k400") 537 | def hiera_huge_16x224(**kwdargs): 538 | return hiera_base_16x224( 539 | embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs 540 | ) 541 | -------------------------------------------------------------------------------- /hiera/hiera_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # mae: https://github.com/facebookresearch/mae 9 | # slowfast: https://github.com/facebookresearch/SlowFast 10 | # -------------------------------------------------------- 11 | 12 | 13 | from functools import partial 14 | from typing import Tuple, Optional 15 | 16 | import math 17 | import torch 18 | import torch.nn as nn 19 | 20 | from .hiera import Hiera, HieraBlock 21 | from .hiera_utils import pretrained_model, undo_windowing, conv_nd 22 | 23 | 24 | def apply_fusion_head(head: nn.Module, x: torch.Tensor) -> torch.Tensor: 25 | if isinstance(head, nn.Identity): 26 | return x 27 | 28 | B, num_mask_units = x.shape[0:2] 29 | # Apply head, e.g [B, #MUs, My, Mx, C] -> head([B * #MUs, C, My, Mx]) 30 | permute = [0] + [len(x.shape) - 2] + list(range(1, len(x.shape) - 2)) 31 | x = head(x.reshape(B * num_mask_units, *x.shape[2:]).permute(permute)) 32 | 33 | # Restore original layout, e.g. [B * #MUs, C', My', Mx'] -> [B, #MUs, My', Mx', C'] 34 | permute = [0] + list(range(2, len(x.shape))) + [1] 35 | x = x.permute(permute).reshape(B, num_mask_units, *x.shape[2:], x.shape[1]) 36 | return x 37 | 38 | 39 | class MaskedAutoencoderHiera(Hiera): 40 | """Masked Autoencoder with Hiera backbone""" 41 | 42 | def __init__( 43 | self, 44 | in_chans: int = 3, 45 | patch_stride: Tuple[int, ...] 
= (4, 4), 46 | mlp_ratio: float = 4.0, 47 | decoder_embed_dim: int = 512, 48 | decoder_depth: int = 8, 49 | decoder_num_heads: int = 16, 50 | norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), 51 | **kwdargs, 52 | ): 53 | super().__init__( 54 | in_chans=in_chans, 55 | patch_stride=patch_stride, 56 | mlp_ratio=mlp_ratio, 57 | norm_layer=norm_layer, 58 | **kwdargs, 59 | ) 60 | 61 | del self.norm, self.head 62 | encoder_dim_out = self.blocks[-1].dim_out 63 | self.encoder_norm = norm_layer(encoder_dim_out) 64 | self.mask_unit_spatial_shape_final = [ 65 | i // s ** (self.q_pool) for i, s in zip(self.mask_unit_size, self.q_stride) 66 | ] 67 | self.tokens_spatial_shape_final = [ 68 | i // s ** (self.q_pool) 69 | for i, s in zip(self.tokens_spatial_shape, self.q_stride) 70 | ] 71 | # -------------------------------------------------------------------------- 72 | # Multi-scale fusion heads 73 | curr_mu_size = self.mask_unit_size 74 | self.multi_scale_fusion_heads = nn.ModuleList() 75 | 76 | for i in self.stage_ends[: self.q_pool]: # resolution constant after q_pool 77 | kernel = [ 78 | i // s for i, s in zip(curr_mu_size, self.mask_unit_spatial_shape_final) 79 | ] 80 | curr_mu_size = [i // s for i, s in zip(curr_mu_size, self.q_stride)] 81 | self.multi_scale_fusion_heads.append( 82 | conv_nd(len(self.q_stride))( 83 | self.blocks[i].dim_out, 84 | encoder_dim_out, 85 | kernel_size=kernel, 86 | stride=kernel, 87 | ) 88 | ) 89 | self.multi_scale_fusion_heads.append(nn.Identity()) # final stage, no transform 90 | 91 | # -------------------------------------------------------------------------- 92 | # MAE decoder specifics 93 | self.decoder_embed = nn.Linear(encoder_dim_out, decoder_embed_dim) 94 | 95 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 96 | 97 | self.decoder_pos_embed = nn.Parameter( 98 | torch.zeros( 99 | 1, math.prod(self.tokens_spatial_shape_final), decoder_embed_dim 100 | ) 101 | ) 102 | 103 | self.decoder_blocks = nn.ModuleList( 104 | [ 105 | HieraBlock( 106 | dim=decoder_embed_dim, 107 | dim_out=decoder_embed_dim, 108 | heads=decoder_num_heads, 109 | norm_layer=norm_layer, 110 | mlp_ratio=mlp_ratio, 111 | ) 112 | for i in range(decoder_depth) 113 | ] 114 | ) 115 | self.decoder_norm = norm_layer(decoder_embed_dim) 116 | 117 | self.pred_stride = patch_stride[-1] * ( 118 | self.q_stride[-1] ** self.q_pool 119 | ) # patch stride of prediction 120 | 121 | self.decoder_pred = nn.Linear( 122 | decoder_embed_dim, 123 | (self.pred_stride ** min(2, len(self.q_stride))) * in_chans, 124 | ) # predictor 125 | # -------------------------------------------------------------------------- 126 | 127 | self.initialize_weights() 128 | 129 | def initialize_weights(self): 130 | nn.init.trunc_normal_(self.mask_token, std=0.02) 131 | nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02) 132 | self.apply(self._mae_init_weights) 133 | 134 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 135 | w = self.patch_embed.proj.weight.data 136 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 137 | 138 | def _mae_init_weights(self, m: nn.Module): 139 | if isinstance(m, nn.Linear): 140 | nn.init.xavier_uniform_(m.weight) 141 | if m.bias is not None: 142 | nn.init.constant_(m.bias, 0) 143 | elif isinstance(m, nn.LayerNorm): 144 | nn.init.constant_(m.bias, 0) 145 | nn.init.constant_(m.weight, 1.0) 146 | 147 | def get_pixel_label_2d( 148 | self, input_img: torch.Tensor, mask: torch.Tensor, norm: bool = True 149 | ) -> torch.Tensor: 150 | # mask (boolean tensor): True must 
correspond to *masked* 151 | input_img = input_img.permute(0, 2, 3, 1) 152 | 153 | size = self.pred_stride 154 | label = input_img.unfold(1, size, size).unfold(2, size, size) 155 | label = label.flatten(1, 2).flatten(2) 156 | label = label[mask] 157 | if norm: 158 | mean = label.mean(dim=-1, keepdim=True) 159 | var = label.var(dim=-1, keepdim=True) 160 | label = (label - mean) / (var + 1.0e-6) ** 0.5 161 | 162 | return label 163 | 164 | def get_pixel_label_3d( 165 | self, input_vid: torch.Tensor, mask: torch.Tensor, norm: bool = True 166 | ) -> torch.Tensor: 167 | # mask (boolean tensor): True must correspond to *masked* 168 | 169 | # We use time strided loss, only take the first frame from each token 170 | input_vid = input_vid[:, :, ::self.patch_stride[0], :, :] 171 | 172 | size = self.pred_stride 173 | label = input_vid.unfold(3, size, size).unfold(4, size, size) 174 | label = label.permute(0, 2, 3, 4, 5, 6, 1) # Different from 2d, mistake during training lol 175 | label = label.flatten(1, 3).flatten(2) 176 | label = label[mask] 177 | 178 | if norm: 179 | mean = label.mean(dim=-1, keepdim=True) 180 | var = label.var(dim=-1, keepdim=True) 181 | label = (label - mean) / (var + 1.0e-6) ** 0.5 182 | 183 | return label 184 | 185 | 186 | def forward_encoder( 187 | self, x: torch.Tensor, mask_ratio: float, mask: Optional[torch.Tensor] = None 188 | ) -> Tuple[torch.Tensor, torch.Tensor]: 189 | 190 | if mask is None: 191 | mask = self.get_random_mask(x, mask_ratio) # [B, #MUs_all] 192 | 193 | # Get multi-scale representations from encoder 194 | _, intermediates = super().forward(x, mask, return_intermediates=True) 195 | # Resolution unchanged after q_pool stages, so skip those features 196 | intermediates = intermediates[: self.q_pool] + intermediates[-1:] 197 | 198 | # Multi-scale fusion 199 | x = 0.0 200 | for head, interm_x in zip(self.multi_scale_fusion_heads, intermediates): 201 | x += apply_fusion_head(head, interm_x) 202 | 203 | x = self.encoder_norm(x) 204 | 205 | return x, mask 206 | 207 | def forward_decoder( 208 | self, x: torch.Tensor, mask: torch.Tensor 209 | ) -> Tuple[torch.Tensor, torch.Tensor]: 210 | # Embed tokens 211 | x = self.decoder_embed(x) 212 | 213 | # Combine visible and mask tokens 214 | 215 | # x: [B, #MUs, *mask_unit_spatial_shape_final, encoder_dim_out] 216 | # mask: [B, #MUs_all] 217 | x_dec = torch.zeros(*mask.shape, *x.shape[2:], device=x.device, dtype=x.dtype) 218 | mask_tokens = self.mask_token.view( 219 | (1,) * (len(mask.shape) + len(x.shape[2:-1])) + (-1,) 220 | ) 221 | mask = mask.reshape(mask.shape + (1,) * len(x.shape[2:])) 222 | mask = mask.expand((-1,) * 2 + x.shape[2:]).bool() 223 | x_dec[mask] = x.flatten() 224 | x_dec = ~mask * mask_tokens + mask * x_dec 225 | 226 | # Get back spatial order 227 | x = undo_windowing( 228 | x_dec, 229 | self.tokens_spatial_shape_final, 230 | self.mask_unit_spatial_shape_final, 231 | ) 232 | mask = undo_windowing( 233 | mask[..., 0:1], 234 | self.tokens_spatial_shape_final, 235 | self.mask_unit_spatial_shape_final, 236 | ) 237 | 238 | # Flatten 239 | x = x.reshape(x.shape[0], -1, x.shape[-1]) 240 | mask = mask.view(x.shape[0], -1) 241 | 242 | # Add pos embed 243 | x = x + self.decoder_pos_embed 244 | 245 | # Apply decoder blocks 246 | for blk in self.decoder_blocks: 247 | x = blk(x) 248 | x = self.decoder_norm(x) 249 | 250 | # Predictor projection 251 | x = self.decoder_pred(x) 252 | 253 | return x, mask 254 | 255 | def forward_loss( 256 | self, x: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor 257 | ) -> 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 258 | """ 259 | Note: in mask, 0 is *visible*, 1 is *masked* 260 | 261 | x: e.g. [B, 3, H, W] 262 | pred: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans] 263 | label: [B * num_pred_tokens, num_pixels_in_pred_patch * in_chans] 264 | """ 265 | if len(self.q_stride) == 2: 266 | label = self.get_pixel_label_2d(x, mask) 267 | elif len(self.q_stride) == 3: 268 | label = self.get_pixel_label_3d(x, mask) 269 | else: 270 | raise NotImplementedError 271 | 272 | pred = pred[mask] 273 | loss = (pred - label) ** 2 274 | 275 | return loss.mean(), pred, label 276 | 277 | def forward( 278 | self, 279 | x: torch.Tensor, 280 | mask_ratio: float = 0.6, 281 | mask: Optional[torch.Tensor] = None, 282 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 283 | 284 | latent, mask = self.forward_encoder(x, mask_ratio, mask=mask) 285 | pred, pred_mask = self.forward_decoder( 286 | latent, mask 287 | ) # pred_mask is mask at resolution of *prediction* 288 | 289 | # Toggle mask, to generate labels for *masked* tokens 290 | return *self.forward_loss(x, pred, ~pred_mask), mask 291 | 292 | 293 | 294 | 295 | # Image Models 296 | 297 | @pretrained_model({ 298 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth", 299 | }, default="mae_in1k") 300 | def mae_hiera_tiny_224(**kwargs): 301 | return MaskedAutoencoderHiera( 302 | embed_dim=96, num_heads=1, stages=(1, 2, 7, 2), q_pool=2, **kwargs, 303 | ) 304 | 305 | 306 | @pretrained_model({ 307 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth", 308 | }, default="mae_in1k") 309 | def mae_hiera_small_224(**kwargs): 310 | return MaskedAutoencoderHiera( 311 | embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), q_pool=2, **kwargs, 312 | ) 313 | 314 | 315 | @pretrained_model({ 316 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth", 317 | }, default="mae_in1k") 318 | def mae_hiera_base_224(**kwargs): 319 | return MaskedAutoencoderHiera( 320 | embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), q_pool=2, **kwargs, 321 | ) 322 | 323 | 324 | @pretrained_model({ 325 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth", 326 | }, default="mae_in1k") 327 | def mae_hiera_base_plus_224(**kwargs): 328 | return MaskedAutoencoderHiera( 329 | embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), q_pool=2, **kwargs, 330 | ) 331 | 332 | 333 | @pretrained_model({ 334 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth", 335 | }, default="mae_in1k") 336 | def mae_hiera_large_224(**kwargs): 337 | return MaskedAutoencoderHiera( 338 | embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), q_pool=2, **kwargs, 339 | ) 340 | 341 | 342 | @pretrained_model({ 343 | "mae_in1k": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth", 344 | }, default="mae_in1k") 345 | def mae_hiera_huge_224(**kwargs): 346 | return MaskedAutoencoderHiera( 347 | embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), q_pool=2, **kwargs, 348 | ) 349 | 350 | 351 | 352 | # Video Models 353 | 354 | @pretrained_model({ 355 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_16x224.pth", 356 | }, default="mae_k400") 357 | def mae_hiera_base_16x224(num_classes: int = 400, **kwdargs): 358 | return MaskedAutoencoderHiera( 359 | num_classes=num_classes, # K400 has 400 classes 360 | input_size=(16, 224, 224), 361 | q_stride=(1, 2, 2), 362 | mask_unit_size=(1, 8, 8), 363 | patch_kernel=(3, 7, 7), 364 | patch_stride=(2, 4, 4), 365 | 
patch_padding=(1, 3, 3), 366 | sep_pos_embed=True, 367 | q_pool=2, 368 | **kwdargs 369 | ) 370 | 371 | 372 | @pretrained_model({ 373 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_16x224.pth", 374 | }, default="mae_k400") 375 | @pretrained_model(None) 376 | def mae_hiera_base_plus_16x224(**kwdargs): 377 | return mae_hiera_base_16x224( 378 | embed_dim=112, num_heads=2, stages=(2, 3, 16, 3), **kwdargs 379 | ) 380 | 381 | 382 | @pretrained_model({ 383 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_16x224.pth", 384 | }, default="mae_k400") 385 | @pretrained_model(None) 386 | def mae_hiera_large_16x224(**kwdargs): 387 | return mae_hiera_base_16x224( 388 | embed_dim=144, num_heads=2, stages=(2, 6, 36, 4), **kwdargs 389 | ) 390 | 391 | 392 | @pretrained_model({ 393 | "mae_k400": "https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_16x224.pth", 394 | }, default="mae_k400") 395 | def mae_hiera_huge_16x224(**kwdargs): 396 | return mae_hiera_base_16x224( 397 | embed_dim=256, num_heads=4, stages=(2, 6, 36, 4), **kwdargs 398 | ) 399 | -------------------------------------------------------------------------------- /hiera/hiera_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # 8 | # Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles 9 | # 10 | # Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan, 11 | # Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed, 12 | # Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer. 13 | # 14 | # Paper: https://arxiv.org/abs/2306.00989/ 15 | # 16 | # References: 17 | # slowfast: https://github.com/facebookresearch/SlowFast 18 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 19 | # -------------------------------------------------------- 20 | 21 | import math 22 | from typing import List, Tuple, Optional, Type, Callable, Dict 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | 29 | def pretrained_model(checkpoints: Dict[str, str], default: str = None) -> Callable: 30 | """ Loads a Hiera model from a pretrained source (if pretrained=True). Use "checkpoint" to specify the checkpoint. """ 31 | 32 | def inner(model_func: Callable) -> Callable: 33 | def model_def(pretrained: bool = False, checkpoint: str = default, strict: bool = True, **kwdargs) -> nn.Module: 34 | if pretrained: 35 | if checkpoints is None: 36 | raise RuntimeError("This model currently doesn't have pretrained weights available.") 37 | elif checkpoint is None: 38 | raise RuntimeError("No checkpoint specified.") 39 | elif checkpoint not in checkpoints: 40 | raise RuntimeError(f"Invalid checkpoint specified ({checkpoint}). 
Options are: {list(checkpoints.keys())}.") 41 | 42 | state_dict = torch.hub.load_state_dict_from_url(checkpoints[checkpoint], map_location="cpu") 43 | 44 | if "head.projection.weight" in state_dict["model_state"]: 45 | # Set the number of classes equal to the state_dict only if the user doesn't want to overwrite it 46 | if "num_classes" not in kwdargs: 47 | kwdargs["num_classes"] = state_dict["model_state"]["head.projection.weight"].shape[0] 48 | # If the user specified a different number of classes, remove the projection weights or else we'll error out 49 | elif kwdargs["num_classes"] != state_dict["model_state"]["head.projection.weight"].shape[0]: 50 | del state_dict["model_state"]["head.projection.weight"] 51 | del state_dict["model_state"]["head.projection.bias"] 52 | 53 | model = model_func(**kwdargs) 54 | if pretrained: 55 | # Disable being strict when trying to load a encoder-decoder model into an encoder-only model 56 | if "decoder_pos_embed" in state_dict["model_state"] and not hasattr(model, "decoder_pos_embed"): 57 | strict = False 58 | 59 | model.load_state_dict(state_dict["model_state"], strict=strict) 60 | 61 | return model 62 | 63 | # Keep some metadata so we can do things that require looping through all available models 64 | model_def.checkpoints = checkpoints 65 | model_def.default = default 66 | 67 | return model_def 68 | 69 | return inner 70 | 71 | 72 | 73 | def conv_nd(n: int) -> Type[nn.Module]: 74 | """ 75 | Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3. 76 | If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises) 77 | """ 78 | return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n] 79 | 80 | 81 | def do_pool(x: torch.Tensor, stride: int) -> torch.Tensor: 82 | # Refer to `Unroll` to see how this performs a maxpool-Nd 83 | return x.view(x.shape[0], stride, -1, x.shape[-1]).max(dim=1).values 84 | 85 | 86 | def get_resized_mask(target_size: torch.Size, mask: torch.Tensor) -> torch.Tensor: 87 | # target_size: [(T), (H), W] 88 | # (spatial) mask: [B, C, (t), (h), w] 89 | if mask is None: 90 | return mask 91 | 92 | assert len(mask.shape[2:]) == len(target_size) 93 | if mask.shape[2:] != target_size: 94 | return F.interpolate(mask.float(), size=target_size) 95 | return mask 96 | 97 | 98 | def do_masked_conv( 99 | x: torch.Tensor, conv: nn.Module, mask: Optional[torch.Tensor] = None 100 | ) -> torch.Tensor: 101 | """Zero-out the masked regions of the input before conv. 102 | Prevents leakage of masked regions when using overlapping kernels. 103 | """ 104 | if conv is None: 105 | return x 106 | if mask is None: 107 | return conv(x) 108 | 109 | mask = get_resized_mask(target_size=x.shape[2:], mask=mask) 110 | return conv(x * mask.bool()) 111 | 112 | 113 | def undo_windowing( 114 | x: torch.Tensor, shape: List[int], mu_shape: List[int] 115 | ) -> torch.Tensor: 116 | """ 117 | Restore spatial organization by undoing windowed organization of mask units. 118 | 119 | Args: 120 | x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C] 121 | shape: current spatial shape, if it were not organized into mask unit 122 | windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C]. 123 | mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx] 124 | Returns: 125 | x: e.g. 
in 2d, [B, #MUy*MUy, #MUx*MUx, C] 126 | """ 127 | D = len(shape) 128 | B, C = x.shape[0], x.shape[-1] 129 | # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C] 130 | num_MUs = [s // mu for s, mu in zip(shape, mu_shape)] 131 | x = x.view(B, *num_MUs, *mu_shape, C) 132 | 133 | # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C] 134 | permute = ( 135 | [0] 136 | + sum( 137 | [list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], 138 | [], 139 | ) 140 | + [len(x.shape) - 1] 141 | ) 142 | x = x.permute(permute).reshape(B, *shape, C) 143 | 144 | return x 145 | 146 | 147 | 148 | class Unroll(nn.Module): 149 | """ 150 | Reorders the tokens such that patches are contiguous in memory. 151 | E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as 152 | [B, (Sy, Sx, H // Sy, W // Sx), C] 153 | 154 | This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1). 155 | Not only is this faster, but it also makes it easy to support inputs of arbitrary 156 | dimensions in addition to patch-wise sparsity. 157 | 158 | Performing this operation multiple times in sequence puts entire windows as contiguous 159 | in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of 160 | size 8x8 would be contiguous in memory, allowing operations like mask unit attention 161 | computed easily and efficiently, while also allowing max to be applied sequentially. 162 | 163 | Note: This means that intermediate values of the model are not in HxW order, so they 164 | need to be re-rolled if you want to use the intermediate values as a HxW feature map. 165 | The last block of the network is fine though, since by then the strides are all consumed. 166 | """ 167 | 168 | def __init__( 169 | self, 170 | input_size: Tuple[int, ...], 171 | patch_stride: Tuple[int, ...], 172 | unroll_schedule: List[Tuple[int, ...]], 173 | ): 174 | super().__init__() 175 | self.size = [i // s for i, s in zip(input_size, patch_stride)] 176 | self.schedule = unroll_schedule 177 | 178 | def forward(self, x: torch.Tensor) -> torch.Tensor: 179 | """ 180 | Input: Flattened patch embeddings [B, N, C] 181 | Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd 182 | """ 183 | B, _, C = x.shape 184 | 185 | cur_size = self.size 186 | x = x.view(*([B] + cur_size + [C])) 187 | 188 | for strides in self.schedule: 189 | # Move patches with the given strides to the batch dimension 190 | 191 | # Create a view of the tensor with the patch stride as separate dims 192 | # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C] 193 | cur_size = [i // s for i, s in zip(cur_size, strides)] 194 | new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C] 195 | x = x.view(new_shape) 196 | 197 | # Move the patch stride into the batch dimension 198 | # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C] 199 | L = len(new_shape) 200 | permute = ( 201 | [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1] 202 | ) 203 | x = x.permute(permute) 204 | 205 | # Now finally flatten the relevant dims into the batch dimension 206 | x = x.flatten(0, len(strides)) 207 | B *= math.prod(strides) 208 | 209 | x = x.reshape(-1, math.prod(self.size), C) 210 | return x 211 | 212 | 213 | class Reroll(nn.Module): 214 | """ 215 | Undos the "unroll" operation so that you can use intermediate features. 
216 | """ 217 | 218 | def __init__( 219 | self, 220 | input_size: Tuple[int, ...], 221 | patch_stride: Tuple[int, ...], 222 | unroll_schedule: List[Tuple[int, ...]], 223 | stage_ends: List[int], 224 | q_pool: int, 225 | ): 226 | super().__init__() 227 | self.size = [i // s for i, s in zip(input_size, patch_stride)] 228 | 229 | # The first stage has to reverse everything 230 | # The next stage has to reverse all but the first unroll, etc. 231 | self.schedule = {} 232 | size = self.size 233 | for i in range(stage_ends[-1] + 1): 234 | self.schedule[i] = unroll_schedule, size 235 | # schedule unchanged if no pooling at a stage end 236 | if i in stage_ends[:q_pool]: 237 | if len(unroll_schedule) > 0: 238 | size = [n // s for n, s in zip(size, unroll_schedule[0])] 239 | unroll_schedule = unroll_schedule[1:] 240 | 241 | def forward( 242 | self, x: torch.Tensor, block_idx: int, mask: torch.Tensor = None 243 | ) -> torch.Tensor: 244 | """ 245 | Roll the given tensor back up to spatial order assuming it's from the given block. 246 | 247 | If no mask is provided: 248 | - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc. 249 | If a mask is provided: 250 | - Returns [B, #MUs, MUy, MUx, C] for 2d, etc. 251 | """ 252 | schedule, size = self.schedule[block_idx] 253 | B, N, C = x.shape 254 | 255 | D = len(size) 256 | cur_mu_shape = [1] * D 257 | 258 | for strides in schedule: 259 | # Extract the current patch from N 260 | x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C) 261 | 262 | # Move that patch into the current MU 263 | # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C] 264 | L = len(x.shape) 265 | permute = ( 266 | [0, 1 + D] 267 | + sum( 268 | [list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))], 269 | [], 270 | ) 271 | + [L - 1] 272 | ) 273 | x = x.permute(permute) 274 | 275 | # Reshape to [B, N//(Sy*Sx), *MU, C] 276 | for i in range(D): 277 | cur_mu_shape[i] *= strides[i] 278 | x = x.reshape(B, -1, *cur_mu_shape, C) 279 | N = x.shape[1] 280 | 281 | # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C]) 282 | x = x.view(B, N, *cur_mu_shape, C) 283 | 284 | # If masked, return [B, #MUs, MUy, MUx, C] 285 | if mask is not None: 286 | return x 287 | 288 | # If not masked, we can return [B, H, W, C] 289 | x = undo_windowing(x, size, cur_mu_shape) 290 | 291 | return x 292 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | 8 | dependencies = ["torch", "timm"] 9 | 10 | from hiera import ( 11 | hiera_tiny_224, 12 | hiera_small_224, 13 | hiera_base_224, 14 | hiera_base_plus_224, 15 | hiera_large_224, 16 | hiera_huge_224, 17 | 18 | hiera_base_16x224, 19 | hiera_base_plus_16x224, 20 | hiera_large_16x224, 21 | hiera_huge_16x224, 22 | 23 | mae_hiera_tiny_224, 24 | mae_hiera_small_224, 25 | mae_hiera_base_224, 26 | mae_hiera_base_plus_224, 27 | mae_hiera_large_224, 28 | mae_hiera_huge_224, 29 | ) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | 8 | from setuptools import find_packages, setup 9 | 10 | setup( 11 | name="hiera-transformer", 12 | version="0.1.4", 13 | author="Chaitanya Ryali, Daniel Bolya", 14 | url="https://github.com/facebookresearch/hiera", 15 | description="A fast, powerful, and simple hierarchical vision transformer", 16 | install_requires=["torch>=1.8.1", "timm>=0.4.12", "tqdm", "packaging"], 17 | packages=find_packages(exclude=("examples", "build")), 18 | license = 'Apache 2.0', 19 | long_description=open("README.md", "r", encoding="utf-8").read(), 20 | long_description_content_type="text/markdown", 21 | python_requires=">=3.8.0", 22 | classifiers=[ 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Education", 25 | "Intended Audience :: Science/Research", 26 | "License :: OSI Approved :: Apache Software License", 27 | "Operating System :: OS Independent", 28 | "Programming Language :: Python :: 3", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Programming Language :: Python :: 3.8", 31 | "Programming Language :: Python :: 3.9", 32 | "Programming Language :: Python :: 3.10", 33 | "Programming Language :: Python :: 3.11", 34 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 35 | ], 36 | 37 | ) 38 | --------------------------------------------------------------------------------
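To close the tour of the repository, here is a minimal usage sketch (not part of the repository) of the torch.hub entrypoints that `hubconf.py` exposes. It assumes internet access to download the default pretrained checkpoint (`mae_in1k_ft_in1k`, defined in `hiera/hiera.py`) and that `timm` is installed; the input tensor is a random stand-in for a properly preprocessed image.

```py
import torch

# Load an ImageNet-1k finetuned Hiera-B through torch.hub (entrypoint re-exported in hubconf.py).
model = torch.hub.load("facebookresearch/hiera", "hiera_base_224", pretrained=True)
model.eval()

x = torch.rand(1, 3, 224, 224)  # random stand-in for a preprocessed 224x224 image
with torch.no_grad():
    probs = model(x)  # the classification Head applies softmax in eval mode

print(probs.shape)  # torch.Size([1, 1000]) for the default in1k head
```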