├── .gitignore ├── LICENSE ├── README.md ├── assets ├── example_images │ ├── arkit_depth.png │ └── image.jpg └── teaser.gif ├── promptda ├── main.py ├── model │ ├── blocks.py │ ├── config.py │ └── dpt.py ├── promptda.py ├── scripts │ ├── generate_video.py │ ├── infer_stray_scan.py │ └── sanity_check.py └── utils │ ├── depth_utils.py │ ├── io_wrapper.py │ ├── logger.py │ └── parallel_utils.py ├── requirements.txt ├── setup.py └── torchhub ├── README.md └── facebookresearch_dinov2_main ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MODEL_CARD.md ├── README.md ├── conda.yaml ├── dinov2 ├── __init__.py ├── configs │ ├── __init__.py │ ├── eval │ │ ├── vitb14_pretrain.yaml │ │ ├── vitg14_pretrain.yaml │ │ ├── vitl14_pretrain.yaml │ │ └── vits14_pretrain.yaml │ ├── ssl_default_config.yaml │ └── train │ │ ├── vitg14.yaml │ │ ├── vitl14.yaml │ │ └── vitl16_short.yaml ├── distributed │ └── __init__.py ├── eval │ ├── __init__.py │ ├── knn.py │ ├── linear.py │ ├── log_regression.py │ ├── metrics.py │ ├── setup.py │ └── utils.py ├── fsdp │ └── __init__.py ├── layers │ ├── __init__.py │ ├── attention.py │ ├── block.py │ ├── dino_head.py │ ├── drop_path.py │ ├── layer_scale.py │ ├── mlp.py │ ├── patch_embed.py │ └── swiglu_ffn.py ├── logging │ ├── __init__.py │ └── helpers.py ├── loss │ ├── __init__.py │ ├── dino_clstoken_loss.py │ ├── ibot_patch_loss.py │ └── koleo_loss.py ├── models │ ├── __init__.py │ └── vision_transformer.py ├── run │ ├── __init__.py │ ├── eval │ │ ├── knn.py │ │ ├── linear.py │ │ └── log_regression.py │ ├── submit.py │ └── train │ │ └── train.py ├── train │ ├── __init__.py │ ├── ssl_meta_arch.py │ └── train.py └── utils │ ├── __init__.py │ ├── cluster.py │ ├── config.py │ ├── dtype.py │ ├── param_groups.py │ └── utils.py ├── hubconf.py ├── pyproject.toml ├── scripts └── lint.sh ├── setup.cfg ├── setup.py ├── utils.py └── vision_transformer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # delete files larger than 10MiB 2 | **basicModel_neutral_lbs_10_207_0_v1.0.0.pkl 3 | 4 | .vscode 5 | .hydra 6 | *.txt 7 | # All file or folders start with tmp will be ignored 8 | tmp* 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | results/ 22 | checkpoints/ 23 | data/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/#use-with-ide 121 | .pdm.toml 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # 167 | .DS_Store/ 168 | 169 | # PyCharm 170 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 171 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 172 | # and can be added to the global gitignore or merged into this file. For a more nuclear 173 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
174 | #.idea/ 175 | 176 | # torchsparse 177 | torchsparse 178 | 179 | # tensorboard 180 | tensorboard 181 | 182 | # glove 183 | glove 184 | *.jpg 185 | *.png 186 | *.mp4 187 | 188 | *.ipynb 189 | *.ply 190 | 3rdparty 191 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation 2 | ### [Project Page](https://promptda.github.io/) | [Paper](https://promptda.github.io/assets/main_paper_with_supp.pdf) | [Hugging Face Demo](https://huggingface.co/spaces/depth-anything/PromptDA) | [Interactive Results](https://promptda.github.io/interactive.html) | [Data](https://promptda.github.io/) 3 | 4 | > Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation 5 | > [Haotong Lin](https://haotongl.github.io/), 6 | [Sida Peng](https://pengsida.net/), 7 | [Jingxiao Chen](https://scholar.google.com/citations?user=-zs1V28AAAAJ), 8 | [Songyou Peng](https://pengsongyou.github.io/), 9 | [Jiaming Sun](https://jiamingsun.me/), 10 | [Minghuan Liu](https://minghuanliu.com/), 11 | [Hujun Bao](http://www.cad.zju.edu.cn/home/bao/), 12 | [Jiashi Feng](https://scholar.google.com/citations?user=Q8iay0gAAAAJ), 13 | [Xiaowei Zhou](https://www.xzhou.me/), 14 | [Bingyi Kang](https://bingykang.github.io/) 15 | > CVPR 2025 16 | 17 | ![teaser](assets/teaser.gif) 18 | 19 | 20 | ## 🛠️ Installation 21 | 22 |
Setting up the environment 23 | 24 | ```bash 25 | git clone https://github.com/DepthAnything/PromptDA.git 26 | cd PromptDA 27 | pip install -r requirements.txt 28 | pip install -e . 29 | sudo apt install ffmpeg # for video generation 30 | ``` 31 |
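To verify the installation, you can run the bundled sanity-check script. It downloads the large checkpoint from Hugging Face, runs a single prediction on the bundled example image, and writes its outputs under `results/`, so it needs network access and a CUDA device:

```bash
python3 -m promptda.scripts.sanity_check
```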
32 |
Pre-trained Models 33 | 34 | | Model | Params | Checkpoint | 35 | |:-|-:|:-:| 36 | | Prompt-Depth-Anything-Large | 340M | [Download](https://huggingface.co/depth-anything/prompt-depth-anything-vitl/resolve/main/model.ckpt) | 37 | | Prompt-Depth-Anything-Small | 25.1M | [Download](https://huggingface.co/depth-anything/prompt-depth-anything-vits/resolve/main/model.ckpt) | 38 | | Prompt-Depth-Anything-Small-Transparent | 25.1M | [Download](https://huggingface.co/depth-anything/prompt-depth-anything-vits-transparent/resolve/main/model.ckpt) | 39 | 40 | Only Prompt-Depth-Anything-Large is used for benchmarking in our paper. Prompt-Depth-Anything-Small-Transparent is further fine-tuned for 10K steps on the [HAMMER dataset](https://github.com/Junggy/HAMMER-dataset) with our iPhone LiDAR simulation method to improve performance on transparent objects. 41 |
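If you prefer to work from a locally downloaded checkpoint rather than the Hugging Face Hub, the `PromptDA` constructor also accepts an explicit `encoder` and `ckpt_path`. A minimal sketch, assuming you saved the small checkpoint to the hypothetical path `checkpoints/promptda_vits.ckpt`:

```python
from promptda.promptda import PromptDA

# 'vits' selects the small backbone configuration; the ckpt_path below is only an
# example location for the "Prompt-Depth-Anything-Small" file from the table above.
model = PromptDA(encoder='vits', ckpt_path='checkpoints/promptda_vits.ckpt').to('cuda').eval()
```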
43 | 44 | 45 | ## 🚀 Usage 46 |
Example usage 47 | 48 | ```python 49 | from promptda.promptda import PromptDA 50 | from promptda.utils.io_wrapper import load_image, load_depth, save_depth 51 | 52 | DEVICE = 'cuda' 53 | image_path = "assets/example_images/image.jpg" 54 | prompt_depth_path = "assets/example_images/arkit_depth.png" 55 | image = load_image(image_path).to(DEVICE) 56 | prompt_depth = load_depth(prompt_depth_path).to(DEVICE) # 192x256, ARKit LiDAR depth in meters 57 | 58 | model = PromptDA.from_pretrained("depth-anything/prompt-depth-anything-vitl").to(DEVICE).eval() 59 | depth = model.predict(image, prompt_depth) # HxW, depth in meters 60 | 61 | save_depth(depth, prompt_depth=prompt_depth, image=image) 62 | ``` 63 |
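If you want to inspect the prediction directly instead of relying on `save_depth`, you can reuse the repository's own colorization helper. A minimal sketch, assuming `depth` is the tensor returned by `model.predict` above:

```python
import imageio
from promptda.utils.depth_utils import visualize_depth

depth_np = depth.squeeze().cpu().numpy()   # collapse singleton dims -> (H, W) metric depth in meters
depth_vis = visualize_depth(depth_np)      # (H, W, 3) uint8 color map
imageio.imwrite("depth_vis.jpg", depth_vis)
```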
64 | 65 | 66 | ## 📸 Running on your own capture 67 | 68 | You can use the [Stray Scanner App](https://apps.apple.com/us/app/stray-scanner/id1557051662) to capture your own data; it requires an iPhone 12 Pro or later Pro model, or an iPad Pro 2020 or later Pro model. We set up a [Hugging Face Space](https://huggingface.co/spaces/depth-anything/PromptDA) so you can quickly test our model. If you want to obtain video results, please follow the steps below. 69 | 70 |
Testing steps 71 | 72 | 1. Capture a scene with the Stray Scanner App. (Preferably hold the device so that the charging port faces downward or to the right.) 73 | 2. Use the iPhone Files App to compress it into a zip file and transfer it to your computer. Here is an [example screen recording](https://haotongl.github.io/promptda/assets/ScreenRecording_12-16-2024.mp4). 74 | 3. Run the following commands to perform inference with our model and generate the video results. 75 | ```bash 76 | export PATH_TO_ZIP_FILE=data/8b98276b0a.zip # Replace with your own zip file path 77 | export PATH_TO_SAVE_FOLDER=data/8b98276b0a_results # Replace with your own save folder path 78 | python3 -m promptda.scripts.infer_stray_scan --input_path ${PATH_TO_ZIP_FILE} --output_path ${PATH_TO_SAVE_FOLDER} 79 | python3 -m promptda.scripts.generate_video process_stray_scan --input_path ${PATH_TO_ZIP_FILE} --result_path ${PATH_TO_SAVE_FOLDER} 80 | ffmpeg -framerate 60 -i ${PATH_TO_SAVE_FOLDER}/%06d_smooth.jpg -c:v libx264 -pix_fmt yuv420p ${PATH_TO_SAVE_FOLDER}.mp4 81 | ``` 82 |
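If you hit GPU memory limits at full resolution, note that `infer_stray_scan` also exposes a `max_size` argument (default 1008): the longer image side is resized down to this value and floored to a multiple of 14 before inference, so lowering it reduces memory use at the cost of output resolution. For example:

```bash
python3 -m promptda.scripts.infer_stray_scan --input_path ${PATH_TO_ZIP_FILE} --output_path ${PATH_TO_SAVE_FOLDER} --max_size 756
```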
83 | 84 | 85 | ## 👏 Acknowledgements 86 | We thank the generous support from Prof. [Weinan Zhang](https://wnzhang.net/) for robot experiments, including the space, objects and the Unitree H1 robot. We also thank [Zhengbang Zhu](https://scholar.google.com/citations?user=ozatRA0AAAAJ), Jiahang Cao, Xinyao Li, Wentao Dong for their help in setting up the robot platform and collecting robot data. 87 | 88 | ## 📚 Citation 89 | If you find this code useful for your research, please use the following BibTeX entry 90 | ``` 91 | @inproceedings{lin2024promptda, 92 | title={Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation}, 93 | author={Lin, Haotong and Peng, Sida and Chen, Jingxiao and Peng, Songyou and Sun, Jiaming and Liu, Minghuan and Bao, Hujun and Feng, Jiashi and Zhou, Xiaowei and Kang, Bingyi}, 94 | journal={arXiv}, 95 | year={2024} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /assets/example_images/arkit_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DepthAnything/PromptDA/d2b3560c66dfe6ab401dce18a330c00655827b83/assets/example_images/arkit_depth.png -------------------------------------------------------------------------------- /assets/example_images/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DepthAnything/PromptDA/d2b3560c66dfe6ab401dce18a330c00655827b83/assets/example_images/image.jpg -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DepthAnything/PromptDA/d2b3560c66dfe6ab401dce18a330c00655827b83/assets/teaser.gif -------------------------------------------------------------------------------- /promptda/main.py: -------------------------------------------------------------------------------- 1 | from promptda.utils.logger import Log 2 | 3 | 4 | def main(): 5 | pass 6 | 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /promptda/model/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from promptda.utils.logger import Log 5 | import os 6 | import numpy as np 7 | 8 | 9 | def _make_fusion_block(features, use_bn, size=None): 10 | return FeatureFusionDepthBlock( 11 | features, 12 | nn.ReLU(False), 13 | deconv=False, 14 | bn=use_bn, 15 | expand=False, 16 | align_corners=True, 17 | size=size, 18 | ) 19 | 20 | 21 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 22 | scratch = nn.Module() 23 | 24 | out_shape1 = out_shape 25 | out_shape2 = out_shape 26 | out_shape3 = out_shape 27 | if len(in_shape) >= 4: 28 | out_shape4 = out_shape 29 | 30 | if expand: 31 | out_shape1 = out_shape 32 | out_shape2 = out_shape*2 33 | out_shape3 = out_shape*4 34 | if len(in_shape) >= 4: 35 | out_shape4 = out_shape*8 36 | 37 | scratch.layer1_rn = nn.Conv2d( 38 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 39 | ) 40 | scratch.layer2_rn = nn.Conv2d( 41 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 42 | ) 43 | scratch.layer3_rn = nn.Conv2d( 44 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, 
bias=False, groups=groups 45 | ) 46 | if len(in_shape) >= 4: 47 | scratch.layer4_rn = nn.Conv2d( 48 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 49 | ) 50 | 51 | return scratch 52 | 53 | 54 | class ResidualConvUnit(nn.Module): 55 | """Residual convolution module. 56 | """ 57 | 58 | def __init__(self, features, activation, bn): 59 | """Init. 60 | 61 | Args: 62 | features (int): number of features 63 | """ 64 | super().__init__() 65 | 66 | self.bn = bn 67 | 68 | self.groups = 1 69 | 70 | self.conv1 = nn.Conv2d( 71 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 72 | ) 73 | 74 | self.conv2 = nn.Conv2d( 75 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 76 | ) 77 | 78 | if self.bn == True: 79 | self.bn1 = nn.BatchNorm2d(features) 80 | self.bn2 = nn.BatchNorm2d(features) 81 | 82 | self.activation = activation 83 | 84 | self.skip_add = nn.quantized.FloatFunctional() 85 | 86 | def forward(self, x): 87 | """Forward pass. 88 | 89 | Args: 90 | x (tensor): input 91 | 92 | Returns: 93 | tensor: output 94 | """ 95 | 96 | out = self.activation(x) 97 | out = self.conv1(out) 98 | if self.bn == True: 99 | out = self.bn1(out) 100 | 101 | out = self.activation(out) 102 | out = self.conv2(out) 103 | if self.bn == True: 104 | out = self.bn2(out) 105 | 106 | if self.groups > 1: 107 | out = self.conv_merge(out) 108 | 109 | return self.skip_add.add(out, x) 110 | 111 | 112 | class FeatureFusionBlock(nn.Module): 113 | """Feature fusion block. 114 | """ 115 | 116 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): 117 | """Init. 118 | 119 | Args: 120 | features (int): number of features 121 | """ 122 | super(FeatureFusionBlock, self).__init__() 123 | 124 | self.deconv = deconv 125 | self.align_corners = align_corners 126 | 127 | self.groups = 1 128 | 129 | self.expand = expand 130 | out_features = features 131 | if self.expand == True: 132 | out_features = features//2 133 | 134 | self.out_conv = nn.Conv2d( 135 | features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 136 | 137 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 138 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 139 | 140 | self.skip_add = nn.quantized.FloatFunctional() 141 | 142 | self.size = size 143 | 144 | def forward(self, *xs, size=None): 145 | """Forward pass. 146 | 147 | Returns: 148 | tensor: output 149 | """ 150 | output = xs[0] 151 | 152 | if len(xs) == 2: 153 | res = self.resConfUnit1(xs[1]) 154 | output = self.skip_add.add(output, res) 155 | 156 | output = self.resConfUnit2(output) 157 | 158 | if (size is None) and (self.size is None): 159 | modifier = {"scale_factor": 2} 160 | elif size is None: 161 | modifier = {"size": self.size} 162 | else: 163 | modifier = {"size": size} 164 | 165 | output = nn.functional.interpolate( 166 | output, **modifier, mode="bilinear", align_corners=self.align_corners 167 | ) 168 | 169 | output = self.out_conv(output) 170 | 171 | return output 172 | 173 | 174 | class FeatureFusionControlBlock(FeatureFusionBlock): 175 | """Feature fusion block. 176 | """ 177 | 178 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): 179 | """Init. 
180 | 181 | Args: 182 | features (int): number of features 183 | """ 184 | super().__init__(features, activation, deconv, 185 | bn, expand, align_corners, size) 186 | self.copy_block = FeatureFusionBlock( 187 | features, activation, deconv, bn, expand, align_corners, size) 188 | 189 | def forward(self, *xs, size=None): 190 | """Forward pass. 191 | 192 | Returns: 193 | tensor: output 194 | """ 195 | output = xs[0] 196 | 197 | if len(xs) == 2: 198 | res = self.resConfUnit1(xs[1]) 199 | output = self.skip_add.add(output, res) 200 | 201 | output = self.resConfUnit2(output) 202 | 203 | if (size is None) and (self.size is None): 204 | modifier = {"scale_factor": 2} 205 | elif size is None: 206 | modifier = {"size": self.size} 207 | else: 208 | modifier = {"size": size} 209 | 210 | output = nn.functional.interpolate( 211 | output, **modifier, mode="bilinear", align_corners=self.align_corners 212 | ) 213 | 214 | output = self.out_conv(output) 215 | 216 | return output 217 | 218 | 219 | def zero_module(module): 220 | """ 221 | Zero out the parameters of a module and return it. 222 | """ 223 | for p in module.parameters(): 224 | p.detach().zero_() 225 | return module 226 | 227 | 228 | class FeatureFusionDepthBlock(nn.Module): 229 | """Feature fusion block. 230 | """ 231 | 232 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): 233 | """Init. 234 | 235 | Args: 236 | features (int): number of features 237 | """ 238 | super(FeatureFusionDepthBlock, self).__init__() 239 | 240 | self.deconv = deconv 241 | self.align_corners = align_corners 242 | 243 | self.groups = 1 244 | 245 | self.expand = expand 246 | out_features = features 247 | if self.expand == True: 248 | out_features = features//2 249 | 250 | self.out_conv = nn.Conv2d( 251 | features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 252 | 253 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 254 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 255 | self.resConfUnit_depth = nn.Sequential( 256 | nn.Conv2d(1, features, kernel_size=3, stride=1, 257 | padding=1, bias=True, groups=1), 258 | activation, 259 | nn.Conv2d(features, features, kernel_size=3, 260 | stride=1, padding=1, bias=True, groups=1), 261 | activation, 262 | zero_module( 263 | nn.Conv2d(features, features, kernel_size=3, 264 | stride=1, padding=1, bias=True, groups=1) 265 | ) 266 | ) 267 | self.skip_add = nn.quantized.FloatFunctional() 268 | self.size = size 269 | 270 | def forward(self, *xs, prompt_depth=None, size=None): 271 | """Forward pass.
272 | 273 | Returns: 274 | tensor: output 275 | """ 276 | output = xs[0] 277 | 278 | if len(xs) == 2: 279 | res = self.resConfUnit1(xs[1]) 280 | output = self.skip_add.add(output, res) 281 | 282 | output = self.resConfUnit2(output) 283 | 284 | if prompt_depth is not None: 285 | prompt_depth = F.interpolate( 286 | prompt_depth, output.shape[2:], mode='bilinear', align_corners=False) 287 | res = self.resConfUnit_depth(prompt_depth) 288 | output = self.skip_add.add(output, res) 289 | 290 | if (size is None) and (self.size is None): 291 | modifier = {"scale_factor": 2} 292 | elif size is None: 293 | modifier = {"size": self.size} 294 | else: 295 | modifier = {"size": size} 296 | 297 | output = nn.functional.interpolate( 298 | output, **modifier, mode="bilinear", align_corners=self.align_corners 299 | ) 300 | 301 | output = self.out_conv(output) 302 | 303 | return output 304 | -------------------------------------------------------------------------------- /promptda/model/config.py: -------------------------------------------------------------------------------- 1 | model_configs = { 2 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384], 'layer_idxs': [2, 5, 8, 11]}, 3 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768], 'layer_idxs': [2, 5, 8, 11]}, 4 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024], 'layer_idxs': [4, 11, 17, 23]}, 5 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536], 'layer_idxs': [9, 19, 29, 39]} 6 | } 7 | -------------------------------------------------------------------------------- /promptda/model/dpt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Depth Anything V2 2 | # https://github.com/DepthAnything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from promptda.model.blocks import _make_scratch, _make_fusion_block 7 | 8 | 9 | class DPTHead(nn.Module): 10 | def __init__(self, 11 | nclass, 12 | in_channels, 13 | features=256, 14 | out_channels=[256, 512, 1024, 1024], 15 | use_bn=False, 16 | use_clstoken=False, 17 | output_act='sigmoid'): 18 | super(DPTHead, self).__init__() 19 | 20 | self.nclass = nclass 21 | self.use_clstoken = use_clstoken 22 | 23 | self.projects = nn.ModuleList([ 24 | nn.Conv2d( 25 | in_channels=in_channels, 26 | out_channels=out_channel, 27 | kernel_size=1, 28 | stride=1, 29 | padding=0, 30 | ) for out_channel in out_channels 31 | ]) 32 | 33 | self.resize_layers = nn.ModuleList([ 34 | nn.ConvTranspose2d( 35 | in_channels=out_channels[0], 36 | out_channels=out_channels[0], 37 | kernel_size=4, 38 | stride=4, 39 | padding=0), 40 | nn.ConvTranspose2d( 41 | in_channels=out_channels[1], 42 | out_channels=out_channels[1], 43 | kernel_size=2, 44 | stride=2, 45 | padding=0), 46 | nn.Identity(), 47 | nn.Conv2d( 48 | in_channels=out_channels[3], 49 | out_channels=out_channels[3], 50 | kernel_size=3, 51 | stride=2, 52 | padding=1) 53 | ]) 54 | 55 | if use_clstoken: 56 | self.readout_projects = nn.ModuleList() 57 | for _ in range(len(self.projects)): 58 | self.readout_projects.append( 59 | nn.Sequential( 60 | nn.Linear(2 * in_channels, in_channels), 61 | nn.GELU())) 62 | 63 | self.scratch = _make_scratch( 64 | out_channels, 65 | features, 66 | groups=1, 67 | expand=False, 68 | ) 69 | 70 | self.scratch.stem_transpose = None 71 | 72 | self.scratch.refinenet1 = _make_fusion_block( 73 
| features, use_bn) 74 | self.scratch.refinenet2 = _make_fusion_block( 75 | features, use_bn) 76 | self.scratch.refinenet3 = _make_fusion_block( 77 | features, use_bn) 78 | self.scratch.refinenet4 = _make_fusion_block( 79 | features, use_bn) 80 | 81 | head_features_1 = features 82 | head_features_2 = 32 83 | 84 | act_func = nn.Sigmoid() if output_act == 'sigmoid' else nn.Identity() 85 | 86 | if nclass > 1: 87 | self.scratch.output_conv = nn.Sequential( 88 | nn.Conv2d(head_features_1, head_features_1, 89 | kernel_size=3, stride=1, padding=1), 90 | nn.ReLU(True), 91 | nn.Conv2d(head_features_1, nclass, 92 | kernel_size=1, stride=1, padding=0), 93 | ) 94 | else: 95 | self.scratch.output_conv1 = nn.Conv2d( 96 | head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) 97 | 98 | self.scratch.output_conv2 = nn.Sequential( 99 | nn.Conv2d(head_features_1 // 2, head_features_2, 100 | kernel_size=3, stride=1, padding=1), 101 | nn.ReLU(True), 102 | nn.Conv2d(head_features_2, 1, kernel_size=1, 103 | stride=1, padding=0), 104 | act_func, 105 | ) 106 | 107 | def forward(self, out_features, patch_h, patch_w, prompt_depth=None): 108 | out = [] 109 | for i, x in enumerate(out_features): 110 | if self.use_clstoken: 111 | x, cls_token = x[0], x[1] 112 | readout = cls_token.unsqueeze(1).expand_as(x) 113 | x = self.readout_projects[i](torch.cat((x, readout), -1)) 114 | else: 115 | x = x[0] 116 | 117 | x = x.permute(0, 2, 1).reshape( 118 | (x.shape[0], x.shape[-1], patch_h, patch_w)) 119 | 120 | x = self.projects[i](x) 121 | x = self.resize_layers[i](x) 122 | 123 | out.append(x) 124 | 125 | layer_1, layer_2, layer_3, layer_4 = out 126 | 127 | layer_1_rn = self.scratch.layer1_rn(layer_1) 128 | layer_2_rn = self.scratch.layer2_rn(layer_2) 129 | layer_3_rn = self.scratch.layer3_rn(layer_3) 130 | layer_4_rn = self.scratch.layer4_rn(layer_4) 131 | 132 | path_4 = self.scratch.refinenet4( 133 | layer_4_rn, size=layer_3_rn.shape[2:], prompt_depth=prompt_depth) 134 | path_3 = self.scratch.refinenet3( 135 | path_4, layer_3_rn, size=layer_2_rn.shape[2:], prompt_depth=prompt_depth) 136 | path_2 = self.scratch.refinenet2( 137 | path_3, layer_2_rn, size=layer_1_rn.shape[2:], prompt_depth=prompt_depth) 138 | path_1 = self.scratch.refinenet1( 139 | path_2, layer_1_rn, prompt_depth=prompt_depth) 140 | out = self.scratch.output_conv1(path_1) 141 | out_feat = F.interpolate( 142 | out, (int(patch_h * 14), int(patch_w * 14)), 143 | mode="bilinear", align_corners=True) 144 | out = self.scratch.output_conv2(out_feat) 145 | return out 146 | -------------------------------------------------------------------------------- /promptda/promptda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from promptda.model.dpt import DPTHead 4 | from promptda.model.config import model_configs 5 | from promptda.utils.logger import Log 6 | import os 7 | from pathlib import Path 8 | from huggingface_hub import hf_hub_download 9 | 10 | 11 | class PromptDA(nn.Module): 12 | patch_size = 14 # patch size of the pretrained dinov2 model 13 | use_bn = False 14 | use_clstoken = False 15 | output_act = 'sigmoid' 16 | 17 | def __init__(self, 18 | encoder='vitl', 19 | ckpt_path='checkpoints/promptda_vitl.ckpt'): 20 | super().__init__() 21 | model_config = model_configs[encoder] 22 | 23 | self.encoder = encoder 24 | self.model_config = model_config 25 | module_path = Path(__file__) # From anywhere else: module_path = Path(inspect.getfile(PromptDA)) 26 | 
package_base_dir = str(Path(*module_path.parts[:-2])) # extract path to PromptDA module, then resolve to repo base dir for dinov2 load 27 | self.pretrained = torch.hub.load( 28 | f'{package_base_dir}/torchhub/facebookresearch_dinov2_main', 29 | 'dinov2_{:}14'.format(encoder), 30 | source='local', 31 | pretrained=False) 32 | dim = self.pretrained.blocks[0].attn.qkv.in_features 33 | self.depth_head = DPTHead(nclass=1, 34 | in_channels=dim, 35 | features=model_config['features'], 36 | out_channels=model_config['out_channels'], 37 | use_bn=self.use_bn, 38 | use_clstoken=self.use_clstoken, 39 | output_act=self.output_act) 40 | 41 | # mean and std of the pretrained dinov2 model 42 | self.register_buffer('_mean', torch.tensor( 43 | [0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 44 | self.register_buffer('_std', torch.tensor( 45 | [0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 46 | 47 | self.load_checkpoint(ckpt_path) 48 | 49 | @classmethod 50 | def from_pretrained(cls, pretrained_model_name_or_path = None, model_kwargs = None, **hf_kwargs): 51 | """ 52 | Load a model from a checkpoint file. 53 | ### Parameters: 54 | - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. 55 | - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. 56 | - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. 57 | ### Returns: 58 | - A new instance of `MoGe` with the parameters loaded from the checkpoint. 59 | """ 60 | ckpt_path = None 61 | if Path(pretrained_model_name_or_path).exists(): 62 | ckpt_path = pretrained_model_name_or_path 63 | else: 64 | cached_checkpoint_path = hf_hub_download( 65 | repo_id=pretrained_model_name_or_path, 66 | repo_type="model", 67 | filename="model.ckpt", 68 | **hf_kwargs 69 | ) 70 | ckpt_path = cached_checkpoint_path 71 | # model_config = checkpoint['model_config'] 72 | # if model_kwargs is not None: 73 | # model_config.update(model_kwargs) 74 | if model_kwargs is None: 75 | model_kwargs = {} 76 | model_kwargs.update({'ckpt_path': ckpt_path}) 77 | model = cls(**model_kwargs) 78 | return model 79 | 80 | def load_checkpoint(self, ckpt_path): 81 | if os.path.exists(ckpt_path): 82 | Log.info(f'Loading checkpoint from {ckpt_path}') 83 | checkpoint = torch.load(ckpt_path, map_location='cpu') 84 | self.load_state_dict( 85 | {k[9:]: v for k, v in checkpoint['state_dict'].items()}) 86 | else: 87 | Log.warn(f'Checkpoint {ckpt_path} not found') 88 | 89 | def forward(self, x, prompt_depth=None): 90 | assert prompt_depth is not None, 'prompt_depth is required' 91 | prompt_depth, min_val, max_val = self.normalize(prompt_depth) 92 | h, w = x.shape[-2:] 93 | features = self.pretrained.get_intermediate_layers( 94 | (x - self._mean) / self._std, self.model_config['layer_idxs'], 95 | return_class_token=True) 96 | patch_h, patch_w = h // self.patch_size, w // self.patch_size 97 | depth = self.depth_head(features, patch_h, patch_w, prompt_depth) 98 | depth = self.denormalize(depth, min_val, max_val) 99 | return depth 100 | 101 | @torch.no_grad() 102 | def predict(self, 103 | image: torch.Tensor, 104 | prompt_depth: torch.Tensor): 105 | return self.forward(image, prompt_depth) 106 | 107 | def normalize(self, 108 | prompt_depth: torch.Tensor): 109 | B, C, H, W = prompt_depth.shape 110 | min_val = torch.quantile( 111 | prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None] 112 | max_val = torch.quantile( 113 | prompt_depth.reshape(B, -1), 1., dim=1, 
keepdim=True)[:, :, None, None] 114 | prompt_depth = (prompt_depth - min_val) / (max_val - min_val) 115 | return prompt_depth, min_val, max_val 116 | 117 | def denormalize(self, 118 | depth: torch.Tensor, 119 | min_val: torch.Tensor, 120 | max_val: torch.Tensor): 121 | return depth * (max_val - min_val) + min_val 122 | -------------------------------------------------------------------------------- /promptda/scripts/generate_video.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from tqdm.auto import tqdm 3 | import numpy as np 4 | import glob 5 | import os 6 | import cv2 7 | import imageio 8 | from promptda.utils.depth_utils import smooth_min_max, visualize_depth 9 | from promptda.utils.io_wrapper import load_depth, load_image 10 | from promptda.utils.parallel_utils import async_call, parallel_execution 11 | 12 | 13 | def load_depths(depth_paths: list[str]) -> list[np.ndarray]: 14 | depths = parallel_execution(depth_paths, 15 | to_tensor=False, 16 | action=load_depth, 17 | num_processes=32, 18 | desc='Loading depths') 19 | return depths 20 | 21 | 22 | def load_imgs(rgb_paths: list[str], max_size: int) -> list[np.ndarray]: 23 | rgbs = parallel_execution(rgb_paths, 24 | to_tensor=False, 25 | max_size=max_size, 26 | action=load_image, 27 | num_processes=32, 28 | desc='Loading RGB images') 29 | return rgbs 30 | 31 | 32 | def load_result_depths(result_path: str) -> list[np.ndarray]: 33 | depth_paths = sorted(glob.glob(os.path.join(result_path, '*.png'))) 34 | depths = load_depths(depth_paths) 35 | return depths 36 | 37 | 38 | def load_prompt_depths(input_path: str) -> list[np.ndarray]: 39 | prompt_depth_paths = sorted(glob.glob(os.path.join(input_path, 'depth/*.png'))) 40 | prompt_depths = load_depths(prompt_depth_paths) 41 | return prompt_depths 42 | 43 | 44 | def load_rgbs(input_path: str, max_size: int) -> list[np.ndarray]: 45 | rgb_paths = sorted(glob.glob(os.path.join(input_path, 'rgb/*.jpg'))) 46 | rgbs = load_imgs(rgb_paths, max_size) 47 | return rgbs 48 | 49 | @async_call 50 | def generate_frame( 51 | depth: np.ndarray, 52 | min_val: float, 53 | max_val: float, 54 | frame_idx: int, 55 | result_path: str, 56 | prompt_depth: np.ndarray = None, 57 | rgb: np.ndarray = None, 58 | ) -> None: 59 | output_img = visualize_depth(depth, min_val, max_val) 60 | if prompt_depth is not None: 61 | prompt_depth_img = visualize_depth(prompt_depth, min_val, max_val) 62 | if prompt_depth_img.shape[:2] != depth.shape[:2]: 63 | prompt_depth_img = cv2.resize(prompt_depth_img, (depth.shape[1], depth.shape[0])) 64 | output_img = np.concatenate([output_img, prompt_depth_img], axis=1) 65 | if rgb is not None: 66 | if rgb.shape[:2] != depth.shape[:2]: 67 | rgb = cv2.resize(rgb, (depth.shape[1], depth.shape[0])) 68 | if rgb.dtype == np.float32 or rgb.dtype == np.float64: 69 | rgb = (rgb * 255).astype(np.uint8) 70 | output_img = np.concatenate([rgb, output_img], axis=1) 71 | imageio.imwrite(os.path.join(result_path, f'{frame_idx:06d}_smooth.jpg'), output_img) 72 | 73 | 74 | def process_stray_scan(input_path: str = 'data/8b98276b0a', 75 | result_path: str = 'data/8b98276b0a_results', 76 | include_prompt: bool = True, 77 | include_rgb: bool = True, 78 | percentile: float = 2, 79 | smooth_interval: int = 60) -> None: 80 | result_depths = load_result_depths(result_path) 81 | min_vals = [np.percentile(depth, percentile) for depth in result_depths] 82 | max_vals = [np.percentile(depth, 100 - percentile) for depth in result_depths] 83 | min_vals_smooth, 
max_vals_smooth = smooth_min_max(min_vals, max_vals, smooth_interval) 84 | 85 | if include_prompt: 86 | prompt_depths = load_prompt_depths(input_path) 87 | if include_rgb: 88 | rgbs = load_rgbs(input_path, max(result_depths[0].shape)) 89 | 90 | min_len = min(len(result_depths), len(prompt_depths), len(rgbs)) 91 | result_depths = result_depths[:min_len] 92 | prompt_depths = prompt_depths[:min_len] 93 | rgbs = rgbs[:min_len] 94 | 95 | for frame_idx in tqdm(range(len(result_depths)), desc='Generating frames'): 96 | generate_frame(result_depths[frame_idx], 97 | min_vals_smooth[frame_idx], 98 | max_vals_smooth[frame_idx], 99 | frame_idx, 100 | result_path, 101 | prompt_depths[frame_idx] if include_prompt else None, 102 | rgbs[frame_idx] if include_rgb else None) 103 | 104 | def main() -> None: 105 | pass 106 | 107 | 108 | if __name__ == "__main__": 109 | tyro.extras.subcommand_cli_from_dict( 110 | { 111 | 'process_stray_scan': process_stray_scan, 112 | 'main': main, 113 | } 114 | ) 115 | 116 | -------------------------------------------------------------------------------- /promptda/scripts/infer_stray_scan.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | import os 3 | import glob 4 | from tqdm.auto import tqdm 5 | 6 | from promptda.utils.io_wrapper import load_image, load_depth, save_depth 7 | from promptda.utils.parallel_utils import parallel_execution 8 | from promptda.promptda import PromptDA 9 | 10 | def load_data(input_path: str, max_size: int): 11 | root_dir = os.path.dirname(input_path) 12 | scene_name = input_path.split('/')[-1].split('.')[0] 13 | input_dir = os.path.join(root_dir, scene_name) 14 | if not os.path.exists(input_dir): 15 | cmd = f'unzip -o {input_path} -d {root_dir}' 16 | os.system(cmd) 17 | 18 | if not os.path.exists(os.path.join(input_dir, 'rgb')): 19 | os.makedirs(os.path.join(input_dir, 'rgb'), exist_ok=True) 20 | cmd = f'ffmpeg -i {input_dir}/rgb.mp4 -start_number 0 -q:v 2 {input_dir}/rgb/%06d.jpg' 21 | os.system(cmd) 22 | 23 | rgb_files = sorted(glob.glob(os.path.join(input_dir, 'rgb', '*.jpg'))) 24 | prompt_depth_files = sorted(glob.glob(os.path.join(input_dir, 'depth', '*.png'))) 25 | 26 | if len(rgb_files) != len(prompt_depth_files): 27 | min_len = min(len(rgb_files), len(prompt_depth_files)) 28 | rgb_files = rgb_files[:min_len] 29 | prompt_depth_files = prompt_depth_files[:min_len] 30 | 31 | rgbs = parallel_execution(rgb_files, 32 | to_tensor=True, # to_tensor 33 | max_size=max_size, 34 | action=load_image, 35 | num_processes=32, 36 | print_progress=True, 37 | desc='Loading RGB images') 38 | 39 | prompt_depths = parallel_execution(prompt_depth_files, 40 | to_tensor=True, # to_tensor 41 | action=load_depth, 42 | num_processes=32, 43 | print_progress=True, 44 | desc='Loading Prompt Depth') 45 | return rgbs, prompt_depths 46 | 47 | 48 | def main(input_path: str = 'data/8b98276b0a.zip', 49 | output_path: str = 'data/8b98276b0a_results', 50 | max_size: int = 1008, 51 | ): 52 | os.makedirs(output_path, exist_ok=True) 53 | rgbs, prompt_depths = load_data(input_path, max_size) 54 | results = [] 55 | DEVICE = 'cuda' 56 | model = PromptDA.from_pretrained("depth-anything/prompt-depth-anything-vitl").to(DEVICE).eval() 57 | for frame_idx, (rgb, prompt_depth) in tqdm(enumerate(zip(rgbs, prompt_depths)), desc='Inferring', total=len(rgbs)): 58 | rgb, prompt_depth = rgb.to(DEVICE), prompt_depth.to(DEVICE) 59 | depth = model.predict(rgb, prompt_depth) 60 | save_depth(depth.detach().cpu(), 61 | 
output_path=os.path.join(output_path, f'{frame_idx:06d}.png'), 62 | save_vis=True) 63 | 64 | if __name__ == "__main__": 65 | tyro.cli(main) -------------------------------------------------------------------------------- /promptda/scripts/sanity_check.py: -------------------------------------------------------------------------------- 1 | from promptda.promptda import PromptDA 2 | from promptda.utils.io_wrapper import load_image, load_depth, save_depth 3 | 4 | DEVICE = 'cuda' 5 | image_path = "assets/example_images/image.jpg" 6 | prompt_depth_path = "assets/example_images/arkit_depth.png" 7 | image = load_image(image_path).to(DEVICE) 8 | prompt_depth = load_depth(prompt_depth_path).to(DEVICE) # 192x256, ARKit LiDAR depth in meters 9 | 10 | model = PromptDA.from_pretrained("depth-anything/prompt-depth-anything-vitl").to(DEVICE).eval() 11 | depth = model.predict(image, prompt_depth) # HxW, depth in meters 12 | 13 | save_depth(depth, prompt_depth=prompt_depth, image=image) -------------------------------------------------------------------------------- /promptda/utils/depth_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | import open3d as o3d 4 | from scipy.interpolate import CubicSpline 5 | 6 | def visualize_depth(depth: np.ndarray, 7 | depth_min=None, 8 | depth_max=None, 9 | percentile=2, 10 | ret_minmax=False, 11 | cmap='Spectral'): 12 | if depth_min is None: depth_min = np.percentile(depth, percentile) 13 | if depth_max is None: depth_max = np.percentile(depth, 100 - percentile) 14 | if depth_min == depth_max: 15 | depth_min = depth_min - 1e-6 16 | depth_max = depth_max + 1e-6 17 | cm = matplotlib.colormaps[cmap] 18 | depth = ((depth - depth_min) / (depth_max - depth_min)).clip(0, 1) 19 | img_colored_np = cm(depth[None], bytes=False)[:, :, :, 0:3] # value from 0 to 1 20 | img_colored_np = (img_colored_np[0] * 255.0).astype(np.uint8) 21 | if ret_minmax: 22 | return img_colored_np, depth_min, depth_max 23 | else: 24 | return img_colored_np 25 | 26 | 27 | def unproject_depth(depth, 28 | ixt, 29 | depth_min=0.01, 30 | depth_max=None, 31 | color=None, 32 | ext=None, 33 | conf=None, 34 | ret_pcd=False, 35 | clip_box=None): 36 | height, width = depth.shape 37 | x = np.arange(0, width) 38 | y = np.arange(0, height) 39 | xx, yy = np.meshgrid(x, y) 40 | xx = xx.reshape(-1) 41 | yy = yy.reshape(-1) 42 | zz = depth.reshape(-1) 43 | mask = np.ones_like(xx, dtype=np.bool_) 44 | if depth_min is not None: 45 | mask &= zz >= depth_min 46 | if depth_max is not None: 47 | mask &= zz <= depth_max 48 | if conf is not None: 49 | mask &= conf.reshape(-1) == 2 50 | xx = xx[mask] 51 | yy = yy[mask] 52 | zz = zz[mask] 53 | pcd = np.stack([xx, yy, np.ones_like(xx)], axis=1) 54 | pcd = pcd * zz[:, None] 55 | pcd = np.dot(pcd, np.linalg.inv(ixt).T) 56 | if ext is not None: 57 | pcd = np.concatenate([pcd, np.ones((pcd.shape[0], 1))], axis=1) 58 | pcd = np.dot(pcd, np.linalg.inv(ext).T) 59 | new_mask = np.ones_like(pcd[:, 0]).astype(np.bool_) 60 | if clip_box is not None: 61 | assert len(clip_box) == 6 62 | for i, val in enumerate(clip_box): 63 | if val is None: 64 | continue 65 | if i == 0: new_mask &= (pcd[:, 0] <= val) 66 | elif i == 1: new_mask &= (pcd[:, 1] <= val) 67 | elif i == 2: new_mask &= (pcd[:, 2] <= val) 68 | elif i == 3: new_mask &= (pcd[:, 0] >= val) 69 | elif i == 4: new_mask &= (pcd[:, 1] >= val) 70 | elif i == 5: new_mask &= (pcd[:, 2] >= val) 71 | if color is not None: 72 | if color.dtype == np.uint8: 73 | 
color = color.astype(np.float32) / 255. 74 | if ret_pcd: 75 | points = pcd 76 | pcd = o3d.geometry.PointCloud() 77 | pcd.points = o3d.utility.Vector3dVector(points[:, :3][new_mask]) 78 | pcd.colors = o3d.utility.Vector3dVector(color.reshape(-1, 3)[mask][new_mask]) 79 | else: 80 | return pcd[:, :3][new_mask], color.reshape(-1, 3)[mask][new_mask] 81 | else: 82 | if ret_pcd: 83 | points = pcd 84 | pcd = o3d.geometry.PointCloud() 85 | pcd.points = o3d.utility.Vector3dVector(pcd[:, :3][new_mask]) 86 | else: 87 | return pcd[:, :3][new_mask] 88 | return pcd 89 | 90 | def smooth_min_max(min_vals, 91 | max_vals, 92 | interval: int = 60): 93 | ''' 94 | Slerp interpolate and smooth min and max values 95 | Args: 96 | min_vals: list[float] 97 | max_vals: list[float] 98 | Returns: 99 | min_vals_smooth: list[float] 100 | max_vals_smooth: list[float] 101 | ''' 102 | 103 | key_frames = list(range(0, len(min_vals), interval)) 104 | if key_frames[-1] != len(min_vals) - 1: 105 | key_frames.append(len(min_vals) - 1) 106 | 107 | key_frame_indices = np.array(key_frames) 108 | min_key_vals = np.array([min_vals[i] for i in key_frames]) 109 | max_key_vals = np.array([max_vals[i] for i in key_frames]) 110 | 111 | # Use CubicSpline for smooth interpolation 112 | min_spline = CubicSpline( 113 | key_frame_indices, min_key_vals, bc_type='natural') 114 | max_spline = CubicSpline( 115 | key_frame_indices, max_key_vals, bc_type='natural') 116 | 117 | x_full = np.arange(len(min_vals)) 118 | min_vals_smooth = min_spline(x_full) 119 | max_vals_smooth = max_spline(x_full) 120 | # plt.plot(min_vals, label='min_vals') 121 | # plt.plot(min_vals_smooth, label='min_vals_smooth') 122 | # plt.legend() 123 | # plt.savefig('min_vals.png') 124 | return min_vals_smooth, max_vals_smooth 125 | 126 | if __name__ == '__main__': 127 | depth = np.random.rand(100, 100) 128 | visualize_depth(depth) -------------------------------------------------------------------------------- /promptda/utils/io_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | import torch 4 | import os 5 | import matplotlib.pyplot as plt 6 | import cv2 7 | 8 | from promptda.utils.logger import Log 9 | from promptda.utils.depth_utils import visualize_depth 10 | from promptda.utils.parallel_utils import async_call 11 | 12 | # DEVICE = 'cuda' if torch.cuda.is_available( 13 | # ) else 'mps' if torch.backends.mps.is_available() else 'cpu' 14 | 15 | 16 | def to_tensor_func(arr): 17 | if arr.ndim == 2: 18 | arr = arr[:, :, np.newaxis] 19 | return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) 20 | 21 | 22 | def to_numpy_func(tensor): 23 | arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() 24 | if arr.shape[2] == 1: 25 | arr = arr[:, :, 0] 26 | return arr 27 | 28 | 29 | def ensure_multiple_of(x, multiple_of=14): 30 | return int(x // multiple_of * multiple_of) 31 | 32 | 33 | def load_image(image_path, to_tensor=True, max_size=1008, multiple_of=14): 34 | ''' 35 | Load image from path and convert to tensor 36 | max_size // 14 = 0 37 | ''' 38 | image = np.asarray(imageio.imread(image_path)).astype(np.float32) 39 | image = image / 255. 
40 | 41 | max_size = max_size // multiple_of * multiple_of 42 | if max(image.shape) > max_size: 43 | h, w = image.shape[:2] 44 | scale = max_size / max(h, w) 45 | tar_h = ensure_multiple_of(h * scale) 46 | tar_w = ensure_multiple_of(w * scale) 47 | image = cv2.resize(image, (tar_w, tar_h), interpolation=cv2.INTER_AREA) 48 | if to_tensor: 49 | return to_tensor_func(image) 50 | return image 51 | 52 | 53 | def load_depth(depth_path, to_tensor=True): 54 | ''' 55 | Load depth from path and convert to tensor 56 | ''' 57 | if depth_path.endswith('.png'): 58 | depth = np.asarray(imageio.imread(depth_path)).astype(np.float32) 59 | depth = depth / 1000. 60 | elif depth_path.endswith('.npz'): 61 | depth = np.load(depth_path)['depth'] 62 | else: 63 | raise ValueError(f"Unsupported depth format: {depth_path}") 64 | if to_tensor: 65 | return to_tensor_func(depth) 66 | return depth 67 | 68 | 69 | @async_call 70 | def save_depth(depth, 71 | prompt_depth=None, 72 | image=None, 73 | output_path='results/example_depth.png', 74 | save_vis=True): 75 | ''' 76 | Save depth to path 77 | ''' 78 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 79 | depth = to_numpy_func(depth) 80 | uint16_depth = (depth * 1000.).astype(np.uint16) 81 | imageio.imwrite(output_path, uint16_depth) 82 | Log.info(f'Saved depth to {output_path}', tag='save_depth') 83 | 84 | if not save_vis: 85 | return 86 | output_path_ = output_path 87 | output_path = output_path_.replace('.png', '_depth.jpg') 88 | depth_vis, depth_min, depth_max = visualize_depth(depth, ret_minmax=True) 89 | imageio.imwrite(output_path, depth_vis) 90 | 91 | 92 | if prompt_depth is not None: 93 | prompt_depth = to_numpy_func(prompt_depth) 94 | output_path = output_path_.replace('.png', '_prompt_depth.jpg') 95 | prompt_depth_vis = visualize_depth(prompt_depth, 96 | depth_min=depth_min, 97 | depth_max=depth_max) 98 | imageio.imwrite(output_path, prompt_depth_vis) 99 | 100 | if image is not None: 101 | output_path = output_path_.replace('.png', '_image.jpg') 102 | image = to_numpy_func(image) 103 | imageio.imwrite(output_path, (image * 255).astype(np.uint8)) 104 | -------------------------------------------------------------------------------- /promptda/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class Log: 5 | log_on = True # fast switch 6 | used_tags = dict() # To keep track of used tags 7 | _is_main_cached = None # Cache to store the main process check result 8 | 9 | @staticmethod 10 | def is_main_process(): 11 | if Log._is_main_cached is not None: 12 | return Log._is_main_cached 13 | try: 14 | from pytorch_lightning.utilities import rank_zero_only 15 | if rank_zero_only.rank == 0: 16 | Log._is_main_cached = True 17 | else: 18 | Log._is_main_cached = False 19 | except: 20 | Log._is_main_cached = True 21 | return Log._is_main_cached 22 | 23 | @staticmethod 24 | def _should_log(tag): 25 | """ 26 | Determine if the log information should be recorded. 27 | Conditions: log function is enabled, current process is the main process, and the tag has not been used. 
28 | """ 29 | if not Log.log_on: 30 | return False 31 | if not Log.is_main_process(): 32 | return False 33 | if tag is None: 34 | return True 35 | if '__' in tag: 36 | num = int(tag.split('__')[-1]) 37 | tag = tag.split('__')[0] # can output num same information 38 | else: 39 | num = 3 # default 3 40 | 41 | if tag not in Log.used_tags: 42 | Log.used_tags[tag] = num 43 | Log.used_tags[tag] -= 1 44 | if Log.used_tags[tag] >= 0: 45 | return True 46 | else: 47 | return False 48 | 49 | @staticmethod 50 | def info(*args, tag=None): 51 | """ 52 | Output INFO level log information. 53 | """ 54 | if Log._should_log(tag): 55 | print("\033[1;32m[INFO]\033[0;0m", *args) 56 | 57 | @staticmethod 58 | def warn(*args, tag=None): 59 | """ 60 | Output WARN level log information. 61 | """ 62 | if Log._should_log(tag): 63 | print("\033[1;35m[WARN]\033[0;0m", *args) 64 | 65 | @staticmethod 66 | def error(*args, tag=None): 67 | print("\033[1;31m[ERROR]\033[0;0m", *args) 68 | 69 | @staticmethod 70 | def debug(*args, tag=None): 71 | """ 72 | Output DEBUG level log information. 73 | """ 74 | if Log._should_log(tag) and 'HT_DEBUG' in os.environ and os.environ['HT_DEBUG'] == '1': 75 | print("\033[1;33m[DEBUG]\033[0;0m", *args) 76 | -------------------------------------------------------------------------------- /promptda/utils/parallel_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Dict 2 | from multiprocessing.pool import ThreadPool 3 | from tqdm import tqdm 4 | from threading import Thread 5 | import asyncio 6 | from functools import wraps 7 | 8 | 9 | def async_call_func(func): 10 | @wraps(func) 11 | async def wrapper(*args, **kwargs): 12 | loop = asyncio.get_event_loop() 13 | # Use run_in_executor to run the blocking function in a separate thread 14 | return await loop.run_in_executor(None, func, *args, **kwargs) 15 | return wrapper 16 | 17 | 18 | def async_call(fn): 19 | def wrapper(*args, **kwargs): 20 | Thread(target=fn, args=args, kwargs=kwargs).start() 21 | return wrapper 22 | 23 | 24 | def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs): 25 | # Copy from EasyVolCap 26 | # Author: Zhen Xu https://github.com/dendenxu 27 | # NOTE: we expect first arg / or kwargs to be distributed 28 | # NOTE: print_progress arg is reserved 29 | def get_length(args: List, kwargs: Dict): 30 | for a in args: 31 | if isinstance(a, list): 32 | return len(a) 33 | for v in kwargs.values(): 34 | if isinstance(v, list): 35 | return len(v) 36 | raise NotImplementedError 37 | 38 | def get_action_args(length: int, args: List, kwargs: Dict, i: int): 39 | action_args = [(arg[i] if isinstance(arg, list) and len( 40 | arg) == length else arg) for arg in args] 41 | # TODO: Support all types of iterable 42 | action_kwargs = {key: (kwargs[key][i] if isinstance(kwargs[key], list) and len( 43 | kwargs[key]) == length else kwargs[key]) for key in kwargs} 44 | return action_args, action_kwargs 45 | 46 | if not sequential: 47 | # Create ThreadPool 48 | pool = ThreadPool(processes=num_processes) 49 | 50 | # Spawn threads 51 | results = [] 52 | asyncs = [] 53 | length = get_length(args, kwargs) 54 | for i in range(length): 55 | action_args, action_kwargs = get_action_args( 56 | length, args, kwargs, i) 57 | async_result = pool.apply_async(action, action_args, action_kwargs) 58 | asyncs.append(async_result) 59 | 60 | # Join threads and get return values 61 | if not async_return: 
62 |             for async_result in tqdm(asyncs, desc=desc, disable=not print_progress):
63 |                 # will sync the corresponding thread
64 |                 results.append(async_result.get())
65 |             pool.close()
66 |             pool.join()
67 |             return results
68 |         else:
69 |             return pool
70 |     else:
71 |         results = []
72 |         length = get_length(args, kwargs)
73 |         for i in tqdm(range(length), desc=desc, disable=not print_progress):
74 |             action_args, action_kwargs = get_action_args(
75 |                 length, args, kwargs, i)
76 |             async_result = action(*action_args, **action_kwargs)
77 |             results.append(async_result)
78 |         return results
79 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | # python version >= 3.9, <= 3.11
 2 | # torch
 3 | torch==2.0.1
 4 | torchvision==0.15.2
 5 | torchaudio==2.0.2
 6 | xformers==0.0.22
 7 | #
 8 | lightning==2.1.3
 9 | imageio>=2.33.1
10 | Pillow>=10.1.0
11 | imageio-ffmpeg
12 | einops
13 | tqdm
14 | ipdb
15 | # Diffusion
16 | transformers
17 | # ray
18 | termcolor
19 | numpy==1.26.4
20 | opencv-python==4.9.0.80
21 | scipy
22 | matplotlib
23 | h5py
24 | tyro==0.9.2
25 | # open3d
26 | 
27 | # app.py
28 | gradio==4.44.1
29 | gradio-imageslider==0.0.20
30 | spaces
31 | trimesh
32 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup, find_packages
 2 | 
 3 | setup(
 4 |     name="promptda",
 5 |     version="1.0",
 6 |     packages=find_packages(),
 7 |     author="Haotong Lin",
 8 |     description="Prompt Depth Anything",
 9 |     url="https://github.com/DepthAnything/PromptDA",
10 | )
11 | 
--------------------------------------------------------------------------------
/torchhub/README.md:
--------------------------------------------------------------------------------
1 | # Local PyTorch Hub
2 | 
3 | This directory is for loading the DINOv2 encoder locally in case there is no Internet connection.
4 | 
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
 1 | # Code of Conduct
 2 | 
 3 | ## Our Pledge
 4 | 
 5 | In the interest of fostering an open and welcoming environment, we as
 6 | contributors and maintainers pledge to make participation in our project and
 7 | our community a harassment-free experience for everyone, regardless of age, body
 8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
 9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DINOv2 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to DINOv2, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/MODEL_CARD.md: -------------------------------------------------------------------------------- 1 | # Model Card for DINOv2-S/B/L/g 2 | 3 | These are Vision Transformer models trained following the method described in the paper: 4 | "DINOv2: Learning Robust Visual Features without Supervision" 5 | 6 | We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g. 7 | 8 | ## Model Details 9 | The model takes an image as input and returns a class token and patch tokens. 10 | 11 | The embedding dimension is: 12 | - 384 for ViT-S. 13 | - 768 for ViT-B. 14 | - 1024 for ViT-L. 15 | - 1536 for ViT-g. 16 | 17 | The models follow a Transformer architecture, with a patch size of 14. 18 | 19 | For a 224x224 image, this results in 1 class token + 256 patch tokens. 20 | 21 | The models can accept larger images provided the image shapes are multiples of the patch size (14). 22 | If this condition is not verified, the model will crop to the closest smaller multiple of the patch size. 23 | 24 | ### Model Description 25 | 26 | - **Developed by:** Meta AI 27 | - **Model type:** Vision Transformer 28 | - **License:** CC-BY-NC 29 | 30 | - **Repository:** https://github.com/facebookresearch/dinov2 31 | - **Paper:** https://arxiv.org/abs/2304.07193 32 | - **Demo:** https://dinov2.metademolab.com/ 33 | 34 | ## Uses 35 | 36 | The models are vision backbones providing multi-purpose features for downstream tasks. 
37 | 38 | ### Direct Use 39 | 40 | The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results: 41 | - on depth estimation, semantic segmentation, using linear layers. 42 | - on image classification, using k-NN classifiers on the class token. 43 | - on image classification, with logistic regression classifiers applied on the class token. 44 | - on image classification, with a linear layer applied on the class token and the average of the patch tokens. 45 | - on image retrieval using nearest neighbors. 46 | 47 | ### Downstream Use 48 | 49 | It is technically possible to perform fine-tuning on the models, for small gains (we measured +2% on ImageNet-1k classification). 50 | We recommend keeping this as a very last step and only when necessary, as the features already provide good performance out-of-the-box. 51 | 52 | ## Bias, Risks, and Limitations 53 | 54 | Despite improvements thanks to the training method not using annotations, we still observe significant biases in our models toward rich households from Western countries. 55 | 56 | ### Recommendations 57 | 58 | We expect fine-tuning will increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels. 59 | 60 | ## How to Get Started with the Model 61 | 62 | Use the code below to get started with the model. 63 | 64 | ```python 65 | import torch 66 | dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') 67 | dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') 68 | dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') 69 | dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') 70 | ``` 71 | 72 | ## Training Details 73 | 74 | ### Training Data 75 | 76 | - **Training data:** LVD-142M (see paper) 77 | - **Training regime:** fp16 using PyTorch-FSDP mixed-precision. 78 | 79 | ### Training Procedure 80 | 81 | - **Training objective:** 82 | - DINO self-distillation loss with multi-crop 83 | - iBOT masked-image modeling loss 84 | - KoLeo regularization on [CLS] tokens 85 | - **Architectures:** 86 | - ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN 87 | - ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN 88 | - ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN 89 | - ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN 90 | - **Distillation:** 91 | - Distillation follows the standard DINOv2 pretraining procedure, except the teacher is a pretrained ViT-g, frozen. 92 | 93 | ## Evaluation 94 | 95 | We refer users to the associated paper for the evaluation protocols. 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 |
modelImageNet-1kNYU-Depth v2SUN-RGBDADE20kiNaturalist 2018Oxford-H
taskclassif. (acc)classif. (acc)classif. V2 (acc)depth (RMSE)depth (RMSE)segm. (mAP)classif. (acc)retrieval (mAP)
k-NNlinearlinearlinear
4 layers
NYU-D transfermultiscalelinearnearest neighbor
ViT-S/1479.0%81.1%70.8%0.4170.43147.269.5%43.2
ViT-B/1482.1%84.5%74.9%0.3620.40051.376.3%49.5
ViT-L/1483.5%86.3%77.6%0.3330.39653.179.8%54.0
ViT-g/1483.5%86.5%78.4%0.2980.36253.081.6%52.3
174 | 175 | ## Environmental Impact 176 | 177 | - **Hardware Type:** Nvidia A100 178 | - **Hours used:** 22,000 for ViT-g, 4,500 for ViT-S distillation, 5,300 for ViT-B distillation, 8,000 for ViT-L distillation 179 | - **Cloud Provider:** Private infra 180 | - **Compute Region:** USA 181 | - **Carbon Emitted:** 7t CO2eq 182 | 183 | #### Hardware 184 | 185 | Nvidia A100 GPUs 186 | 187 | #### Software 188 | 189 | PyTorch 2.0, 190 | xFormers 0.0.18 191 | 192 | **BibTeX** 193 | 194 | ``` 195 | @misc{oquab2023dinov2, 196 | title={DINOv2: Learning Robust Visual Features without Supervision}, 197 | author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr}, 198 | journal={arXiv:2304.07193}, 199 | year={2023} 200 | } 201 | ``` 202 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/conda.yaml: -------------------------------------------------------------------------------- 1 | name: dinov2 2 | channels: 3 | - defaults 4 | - pytorch 5 | - nvidia 6 | - xformers 7 | - conda-forge 8 | dependencies: 9 | - python=3.9 10 | - pytorch::pytorch=2.0.0 11 | - pytorch::pytorch-cuda=11.7.0 12 | - pytorch::torchvision=0.15.0 13 | - omegaconf 14 | - torchmetrics=0.10.3 15 | - fvcore 16 | - iopath 17 | - xformers::xformers=0.0.18 18 | - pip 19 | - pip: 20 | - git+https://github.com/facebookincubator/submitit 21 | - --extra-index-url https://pypi.nvidia.com 22 | - cuml-cu11 23 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __version__ = "0.0.1" 8 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pathlib 8 | 9 | from omegaconf import OmegaConf 10 | 11 | 12 | def load_config(config_name: str): 13 | config_filename = config_name + ".yaml" 14 | return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) 15 | 16 | 17 | dinov2_default_config = load_config("ssl_default_config") 18 | 19 | 20 | def load_and_merge_config(config_name: str): 21 | default_config = OmegaConf.create(dinov2_default_config) 22 | loaded_config = load_config(config_name) 23 | return OmegaConf.merge(default_config, loaded_config) 24 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitb14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_base 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitg14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_giant2 3 | patch_size: 14 4 | ffn_layer: swiglufused 5 | crops: 6 | global_crops_size: 518 # this is to set up the position embeddings properly 7 | local_crops_size: 98 -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitl14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_large 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vits14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_small 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/ssl_default_config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHTS: '' 3 | compute_precision: 4 | grad_scaler: true 5 | teacher: 6 | backbone: 7 | sharding_strategy: SHARD_GRAD_OP 8 | mixed_precision: 9 | param_dtype: fp16 10 | reduce_dtype: fp16 11 | buffer_dtype: fp32 12 | dino_head: 13 | sharding_strategy: SHARD_GRAD_OP 14 | mixed_precision: 15 | param_dtype: fp16 16 | reduce_dtype: fp16 17 | buffer_dtype: fp32 18 | ibot_head: 19 | sharding_strategy: SHARD_GRAD_OP 20 | mixed_precision: 21 | param_dtype: fp16 22 | reduce_dtype: fp16 23 | buffer_dtype: fp32 24 | student: 25 | backbone: 26 | sharding_strategy: SHARD_GRAD_OP 27 | mixed_precision: 28 | param_dtype: fp16 29 | reduce_dtype: fp16 30 | buffer_dtype: fp32 31 | dino_head: 32 | sharding_strategy: SHARD_GRAD_OP 33 | mixed_precision: 34 | param_dtype: fp16 35 | reduce_dtype: fp32 36 | buffer_dtype: fp32 37 | ibot_head: 38 | sharding_strategy: SHARD_GRAD_OP 39 | mixed_precision: 40 | param_dtype: fp16 41 | reduce_dtype: fp32 42 | buffer_dtype: fp32 43 | dino: 44 | loss_weight: 1.0 45 | head_n_prototypes: 65536 46 | head_bottleneck_dim: 256 47 | head_nlayers: 3 48 | 
head_hidden_dim: 2048 49 | koleo_loss_weight: 0.1 50 | ibot: 51 | loss_weight: 1.0 52 | mask_sample_probability: 0.5 53 | mask_ratio_min_max: 54 | - 0.1 55 | - 0.5 56 | separate_head: false 57 | head_n_prototypes: 65536 58 | head_bottleneck_dim: 256 59 | head_nlayers: 3 60 | head_hidden_dim: 2048 61 | train: 62 | batch_size_per_gpu: 64 63 | dataset_path: ImageNet:split=TRAIN 64 | output_dir: . 65 | saveckp_freq: 20 66 | seed: 0 67 | num_workers: 10 68 | OFFICIAL_EPOCH_LENGTH: 1250 69 | cache_dataset: true 70 | centering: "centering" # or "sinkhorn_knopp" 71 | student: 72 | arch: vit_large 73 | patch_size: 16 74 | drop_path_rate: 0.3 75 | layerscale: 1.0e-05 76 | drop_path_uniform: true 77 | pretrained_weights: '' 78 | ffn_layer: "mlp" 79 | block_chunks: 0 80 | qkv_bias: true 81 | proj_bias: true 82 | ffn_bias: true 83 | teacher: 84 | momentum_teacher: 0.992 85 | final_momentum_teacher: 1 86 | warmup_teacher_temp: 0.04 87 | teacher_temp: 0.07 88 | warmup_teacher_temp_epochs: 30 89 | optim: 90 | epochs: 100 91 | weight_decay: 0.04 92 | weight_decay_end: 0.4 93 | base_lr: 0.004 # learning rate for a batch size of 1024 94 | lr: 0. # will be set after applying scaling rule 95 | warmup_epochs: 10 96 | min_lr: 1.0e-06 97 | clip_grad: 3.0 98 | freeze_last_layer_epochs: 1 99 | scaling_rule: sqrt_wrt_1024 100 | patch_embed_lr_mult: 0.2 101 | layerwise_decay: 0.9 102 | adamw_beta1: 0.9 103 | adamw_beta2: 0.999 104 | crops: 105 | global_crops_scale: 106 | - 0.32 107 | - 1.0 108 | local_crops_number: 8 109 | local_crops_scale: 110 | - 0.05 111 | - 0.32 112 | global_crops_size: 224 113 | local_crops_size: 96 114 | evaluation: 115 | eval_period_iterations: 12500 116 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitg14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 12 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_giant2 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 32 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_large 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl16_short.yaml: 
-------------------------------------------------------------------------------- 1 | # this corresponds to the default config 2 | train: 3 | dataset_path: ImageNet:split=TRAIN 4 | batch_size_per_gpu: 64 5 | student: 6 | block_chunks: 4 7 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import random 9 | import re 10 | import socket 11 | from typing import Dict, List 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | _LOCAL_RANK = -1 17 | _LOCAL_WORLD_SIZE = -1 18 | 19 | 20 | def is_enabled() -> bool: 21 | """ 22 | Returns: 23 | True if distributed training is enabled 24 | """ 25 | return dist.is_available() and dist.is_initialized() 26 | 27 | 28 | def get_global_size() -> int: 29 | """ 30 | Returns: 31 | The number of processes in the process group 32 | """ 33 | return dist.get_world_size() if is_enabled() else 1 34 | 35 | 36 | def get_global_rank() -> int: 37 | """ 38 | Returns: 39 | The rank of the current process within the global process group. 40 | """ 41 | return dist.get_rank() if is_enabled() else 0 42 | 43 | 44 | def get_local_rank() -> int: 45 | """ 46 | Returns: 47 | The rank of the current process within the local (per-machine) process group. 48 | """ 49 | if not is_enabled(): 50 | return 0 51 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE 52 | return _LOCAL_RANK 53 | 54 | 55 | def get_local_size() -> int: 56 | """ 57 | Returns: 58 | The size of the per-machine process group, 59 | i.e. the number of processes per machine. 60 | """ 61 | if not is_enabled(): 62 | return 1 63 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE 64 | return _LOCAL_WORLD_SIZE 65 | 66 | 67 | def is_main_process() -> bool: 68 | """ 69 | Returns: 70 | True if the current process is the main one. 71 | """ 72 | return get_global_rank() == 0 73 | 74 | 75 | def _restrict_print_to_main_process() -> None: 76 | """ 77 | This function disables printing when not in the main process 78 | """ 79 | import builtins as __builtin__ 80 | 81 | builtin_print = __builtin__.print 82 | 83 | def print(*args, **kwargs): 84 | force = kwargs.pop("force", False) 85 | if is_main_process() or force: 86 | builtin_print(*args, **kwargs) 87 | 88 | __builtin__.print = print 89 | 90 | 91 | def _get_master_port(seed: int = 0) -> int: 92 | MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) 93 | 94 | master_port_str = os.environ.get("MASTER_PORT") 95 | if master_port_str is None: 96 | rng = random.Random(seed) 97 | return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) 98 | 99 | return int(master_port_str) 100 | 101 | 102 | def _get_available_port() -> int: 103 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 104 | # A "" host address means INADDR_ANY i.e. binding to all interfaces. 105 | # Note this is not compatible with IPv6. 
106 | s.bind(("", 0)) 107 | port = s.getsockname()[1] 108 | return port 109 | 110 | 111 | _TORCH_DISTRIBUTED_ENV_VARS = ( 112 | "MASTER_ADDR", 113 | "MASTER_PORT", 114 | "RANK", 115 | "WORLD_SIZE", 116 | "LOCAL_RANK", 117 | "LOCAL_WORLD_SIZE", 118 | ) 119 | 120 | 121 | def _collect_env_vars() -> Dict[str, str]: 122 | return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ} 123 | 124 | 125 | def _is_slurm_job_process() -> bool: 126 | return "SLURM_JOB_ID" in os.environ 127 | 128 | 129 | def _parse_slurm_node_list(s: str) -> List[str]: 130 | nodes = [] 131 | # Extract "hostname", "hostname[1-2,3,4-5]," substrings 132 | p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") 133 | for m in p.finditer(s): 134 | prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] 135 | for suffix in suffixes.split(","): 136 | span = suffix.split("-") 137 | if len(span) == 1: 138 | nodes.append(prefix + suffix) 139 | else: 140 | width = len(span[0]) 141 | start, end = int(span[0]), int(span[1]) + 1 142 | nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) 143 | return nodes 144 | 145 | 146 | def _check_env_variable(key: str, new_value: str): 147 | # Only check for difference with preset environment variables 148 | if key in os.environ and os.environ[key] != new_value: 149 | raise RuntimeError(f"Cannot export environment variables as {key} is already set") 150 | 151 | 152 | class _TorchDistributedEnvironment: 153 | def __init__(self): 154 | self.master_addr = "127.0.0.1" 155 | self.master_port = 0 156 | self.rank = -1 157 | self.world_size = -1 158 | self.local_rank = -1 159 | self.local_world_size = -1 160 | 161 | if _is_slurm_job_process(): 162 | return self._set_from_slurm_env() 163 | 164 | env_vars = _collect_env_vars() 165 | if not env_vars: 166 | # Environment is not set 167 | pass 168 | elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): 169 | # Environment is fully set 170 | return self._set_from_preset_env() 171 | else: 172 | # Environment is partially set 173 | collected_env_vars = ", ".join(env_vars.keys()) 174 | raise RuntimeError(f"Partially set environment: {collected_env_vars}") 175 | 176 | if torch.cuda.device_count() > 0: 177 | return self._set_from_local() 178 | 179 | raise RuntimeError("Can't initialize PyTorch distributed environment") 180 | 181 | # Slurm job created with sbatch, submitit, etc... 182 | def _set_from_slurm_env(self): 183 | # logger.info("Initialization from Slurm environment") 184 | job_id = int(os.environ["SLURM_JOB_ID"]) 185 | node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) 186 | nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) 187 | assert len(nodes) == node_count 188 | 189 | self.master_addr = nodes[0] 190 | self.master_port = _get_master_port(seed=job_id) 191 | self.rank = int(os.environ["SLURM_PROCID"]) 192 | self.world_size = int(os.environ["SLURM_NTASKS"]) 193 | assert self.rank < self.world_size 194 | self.local_rank = int(os.environ["SLURM_LOCALID"]) 195 | self.local_world_size = self.world_size // node_count 196 | assert self.local_rank < self.local_world_size 197 | 198 | # Single node job with preset environment (i.e. 
torchrun) 199 | def _set_from_preset_env(self): 200 | # logger.info("Initialization from preset environment") 201 | self.master_addr = os.environ["MASTER_ADDR"] 202 | self.master_port = os.environ["MASTER_PORT"] 203 | self.rank = int(os.environ["RANK"]) 204 | self.world_size = int(os.environ["WORLD_SIZE"]) 205 | assert self.rank < self.world_size 206 | self.local_rank = int(os.environ["LOCAL_RANK"]) 207 | self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 208 | assert self.local_rank < self.local_world_size 209 | 210 | # Single node and GPU job (i.e. local script run) 211 | def _set_from_local(self): 212 | # logger.info("Initialization from local") 213 | self.master_addr = "127.0.0.1" 214 | self.master_port = _get_available_port() 215 | self.rank = 0 216 | self.world_size = 1 217 | self.local_rank = 0 218 | self.local_world_size = 1 219 | 220 | def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": 221 | # See the "Environment variable initialization" section from 222 | # https://pytorch.org/docs/stable/distributed.html for the complete list of 223 | # environment variables required for the env:// initialization method. 224 | env_vars = { 225 | "MASTER_ADDR": self.master_addr, 226 | "MASTER_PORT": str(self.master_port), 227 | "RANK": str(self.rank), 228 | "WORLD_SIZE": str(self.world_size), 229 | "LOCAL_RANK": str(self.local_rank), 230 | "LOCAL_WORLD_SIZE": str(self.local_world_size), 231 | } 232 | if not overwrite: 233 | for k, v in env_vars.items(): 234 | _check_env_variable(k, v) 235 | 236 | os.environ.update(env_vars) 237 | return self 238 | 239 | 240 | def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False): 241 | """Enable distributed mode 242 | 243 | Args: 244 | set_cuda_current_device: If True, call torch.cuda.set_device() to set the 245 | current PyTorch CUDA device to the one matching the local rank. 246 | overwrite: If True, overwrites already set variables. Else fails. 247 | """ 248 | 249 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 250 | if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: 251 | raise RuntimeError("Distributed mode has already been enabled") 252 | torch_env = _TorchDistributedEnvironment() 253 | torch_env.export(overwrite=overwrite) 254 | 255 | if set_cuda_current_device: 256 | torch.cuda.set_device(torch_env.local_rank) 257 | 258 | if allow_nccl_timeout: 259 | # This allows to use torch distributed timeout in a NCCL backend 260 | key, value = "NCCL_ASYNC_ERROR_HANDLING", "1" 261 | if not overwrite: 262 | _check_env_variable(key, value) 263 | os.environ[key] = value 264 | 265 | dist.init_process_group(backend="nccl") 266 | dist.barrier() 267 | 268 | # Finalize setup 269 | _LOCAL_RANK = torch_env.local_rank 270 | _LOCAL_WORLD_SIZE = torch_env.local_world_size 271 | _restrict_print_to_main_process() 272 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/eval/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from enum import Enum 8 | import logging 9 | from typing import Any, Dict, Optional 10 | 11 | import torch 12 | from torch import Tensor 13 | from torchmetrics import Metric, MetricCollection 14 | from torchmetrics.classification import MulticlassAccuracy 15 | from torchmetrics.utilities.data import dim_zero_cat, select_topk 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | class MetricType(Enum): 22 | MEAN_ACCURACY = "mean_accuracy" 23 | MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" 24 | PER_CLASS_ACCURACY = "per_class_accuracy" 25 | IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" 26 | 27 | @property 28 | def accuracy_averaging(self): 29 | return getattr(AccuracyAveraging, self.name, None) 30 | 31 | def __str__(self): 32 | return self.value 33 | 34 | 35 | class AccuracyAveraging(Enum): 36 | MEAN_ACCURACY = "micro" 37 | MEAN_PER_CLASS_ACCURACY = "macro" 38 | PER_CLASS_ACCURACY = "none" 39 | 40 | def __str__(self): 41 | return self.value 42 | 43 | 44 | def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): 45 | if metric_type.accuracy_averaging is not None: 46 | return build_topk_accuracy_metric( 47 | average_type=metric_type.accuracy_averaging, 48 | num_classes=num_classes, 49 | ks=(1, 5) if ks is None else ks, 50 | ) 51 | elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: 52 | return build_topk_imagenet_real_accuracy_metric( 53 | num_classes=num_classes, 54 | ks=(1, 5) if ks is None else ks, 55 | ) 56 | 57 | raise ValueError(f"Unknown metric type {metric_type}") 58 | 59 | 60 | def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): 61 | metrics: Dict[str, Metric] = { 62 | f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks 63 | } 64 | return MetricCollection(metrics) 65 | 66 | 67 | def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): 68 | metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} 69 | return MetricCollection(metrics) 70 | 71 | 72 | class ImageNetReaLAccuracy(Metric): 73 | is_differentiable: bool = False 74 | higher_is_better: Optional[bool] = None 75 | full_state_update: bool = False 76 | 77 | def __init__( 78 | self, 79 | num_classes: int, 80 | top_k: int = 1, 81 | **kwargs: Any, 82 | ) -> None: 83 | super().__init__(**kwargs) 84 | self.num_classes = num_classes 85 | self.top_k = top_k 86 | self.add_state("tp", [], dist_reduce_fx="cat") 87 | 88 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 89 | # preds [B, D] 90 | # target [B, A] 91 | # preds_oh [B, D] with 0 and 1 92 | # select top K highest probabilities, use one hot representation 93 | preds_oh = select_topk(preds, self.top_k) 94 | # target_oh [B, D + 1] with 0 and 1 95 | target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) 96 | target = target.long() 97 | # for undefined targets (-1) use a fake value `num_classes` 98 | target[target == 
-1] = self.num_classes 99 | # fill targets, use one hot representation 100 | target_oh.scatter_(1, target, 1) 101 | # target_oh [B, D] (remove the fake target at index `num_classes`) 102 | target_oh = target_oh[:, :-1] 103 | # tp [B] with 0 and 1 104 | tp = (preds_oh * target_oh == 1).sum(dim=1) 105 | # at least one match between prediction and target 106 | tp.clip_(max=1) 107 | # ignore instances where no targets are defined 108 | mask = target_oh.sum(dim=1) > 0 109 | tp = tp[mask] 110 | self.tp.append(tp) # type: ignore 111 | 112 | def compute(self) -> Tensor: 113 | tp = dim_zero_cat(self.tp) # type: ignore 114 | return tp.float().mean() 115 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/eval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | from typing import Any, List, Optional, Tuple 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | from dinov2.models import build_model_from_cfg 14 | from dinov2.utils.config import setup 15 | import dinov2.utils.utils as dinov2_utils 16 | 17 | 18 | def get_args_parser( 19 | description: Optional[str] = None, 20 | parents: Optional[List[argparse.ArgumentParser]] = None, 21 | add_help: bool = True, 22 | ): 23 | parser = argparse.ArgumentParser( 24 | description=description, 25 | parents=parents or [], 26 | add_help=add_help, 27 | ) 28 | parser.add_argument( 29 | "--config-file", 30 | type=str, 31 | help="Model configuration file", 32 | ) 33 | parser.add_argument( 34 | "--pretrained-weights", 35 | type=str, 36 | help="Pretrained model weights", 37 | ) 38 | parser.add_argument( 39 | "--output-dir", 40 | default="", 41 | type=str, 42 | help="Output directory to write results and logs", 43 | ) 44 | parser.add_argument( 45 | "--opts", 46 | help="Extra configuration options", 47 | default=[], 48 | nargs="+", 49 | ) 50 | return parser 51 | 52 | 53 | def get_autocast_dtype(config): 54 | teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype 55 | if teacher_dtype_str == "fp16": 56 | return torch.half 57 | elif teacher_dtype_str == "bf16": 58 | return torch.bfloat16 59 | else: 60 | return torch.float 61 | 62 | 63 | def build_model_for_eval(config, pretrained_weights): 64 | model, _ = build_model_from_cfg(config, only_teacher=True) 65 | dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") 66 | model.eval() 67 | model.cuda() 68 | return model 69 | 70 | 71 | def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: 72 | cudnn.benchmark = True 73 | config = setup(args) 74 | model = build_model_for_eval(config, args.pretrained_weights) 75 | autocast_dtype = get_autocast_dtype(config) 76 | return model, autocast_dtype 77 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/eval/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | from typing import Dict, Optional 9 | 10 | import torch 11 | from torch import nn 12 | from torchmetrics import MetricCollection 13 | 14 | from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader 15 | import dinov2.distributed as distributed 16 | from dinov2.logging import MetricLogger 17 | 18 | 19 | logger = logging.getLogger("dinov2") 20 | 21 | 22 | class ModelWithNormalize(torch.nn.Module): 23 | def __init__(self, model): 24 | super().__init__() 25 | self.model = model 26 | 27 | def forward(self, samples): 28 | return nn.functional.normalize(self.model(samples), dim=1, p=2) 29 | 30 | 31 | class ModelWithIntermediateLayers(nn.Module): 32 | def __init__(self, feature_model, n_last_blocks, autocast_ctx): 33 | super().__init__() 34 | self.feature_model = feature_model 35 | self.feature_model.eval() 36 | self.n_last_blocks = n_last_blocks 37 | self.autocast_ctx = autocast_ctx 38 | 39 | def forward(self, images): 40 | with torch.inference_mode(): 41 | with self.autocast_ctx(): 42 | features = self.feature_model.get_intermediate_layers( 43 | images, self.n_last_blocks, return_class_token=True 44 | ) 45 | return features 46 | 47 | 48 | @torch.inference_mode() 49 | def evaluate( 50 | model: nn.Module, 51 | data_loader, 52 | postprocessors: Dict[str, nn.Module], 53 | metrics: Dict[str, MetricCollection], 54 | device: torch.device, 55 | criterion: Optional[nn.Module] = None, 56 | ): 57 | model.eval() 58 | if criterion is not None: 59 | criterion.eval() 60 | 61 | for metric in metrics.values(): 62 | metric = metric.to(device) 63 | 64 | metric_logger = MetricLogger(delimiter=" ") 65 | header = "Test:" 66 | 67 | for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): 68 | outputs = model(samples.to(device)) 69 | targets = targets.to(device) 70 | 71 | if criterion is not None: 72 | loss = criterion(outputs, targets) 73 | metric_logger.update(loss=loss.item()) 74 | 75 | for k, metric in metrics.items(): 76 | metric_inputs = postprocessors[k](outputs, targets) 77 | metric.update(**metric_inputs) 78 | 79 | metric_logger.synchronize_between_processes() 80 | logger.info(f"Averaged stats: {metric_logger}") 81 | 82 | stats = {k: metric.compute() for k, metric in metrics.items()} 83 | metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 84 | return metric_logger_stats, stats 85 | 86 | 87 | def all_gather_and_flatten(tensor_rank): 88 | tensor_all_ranks = torch.empty( 89 | distributed.get_global_size(), 90 | *tensor_rank.shape, 91 | dtype=tensor_rank.dtype, 92 | device=tensor_rank.device, 93 | ) 94 | tensor_list = list(tensor_all_ranks.unbind(0)) 95 | torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) 96 | return tensor_all_ranks.flatten(end_dim=1) 97 | 98 | 99 | def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): 100 | dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) 101 | sample_count = len(dataset_with_enumerated_targets) 102 | data_loader = make_data_loader( 103 | dataset=dataset_with_enumerated_targets, 104 | batch_size=batch_size, 105 | num_workers=num_workers, 106 | sampler_type=SamplerType.DISTRIBUTED, 107 | drop_last=False, 108 | shuffle=False, 109 | ) 110 | return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) 111 | 112 | 113 | @torch.inference_mode() 114 | def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): 115 | gather_device = 
torch.device("cpu") if gather_on_cpu else torch.device("cuda") 116 | metric_logger = MetricLogger(delimiter=" ") 117 | features, all_labels = None, None 118 | for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): 119 | samples = samples.cuda(non_blocking=True) 120 | labels_rank = labels_rank.cuda(non_blocking=True) 121 | index = index.cuda(non_blocking=True) 122 | features_rank = model(samples).float() 123 | 124 | # init storage feature matrix 125 | if features is None: 126 | features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) 127 | labels_shape = list(labels_rank.shape) 128 | labels_shape[0] = sample_count 129 | all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) 130 | logger.info(f"Storing features into tensor of shape {features.shape}") 131 | 132 | # share indexes, features and labels between processes 133 | index_all = all_gather_and_flatten(index).to(gather_device) 134 | features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) 135 | labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) 136 | 137 | # update storage feature matrix 138 | if len(index_all) > 0: 139 | features.index_copy_(0, index_all, features_all_ranks) 140 | all_labels.index_copy_(0, index_all, labels_all_ranks) 141 | 142 | logger.info(f"Features shape: {tuple(features.shape)}") 143 | logger.info(f"Labels shape: {tuple(all_labels.shape)}") 144 | 145 | assert torch.all(all_labels > -1) 146 | 147 | return features, all_labels 148 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/fsdp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | from typing import Any 9 | 10 | import torch 11 | import dinov2.distributed as distributed 12 | from functools import partial 13 | from fvcore.common.checkpoint import Checkpointer 14 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 15 | from torch.distributed.fsdp import ShardingStrategy 16 | from torch.distributed.fsdp import MixedPrecision 17 | from torch.distributed.fsdp import StateDictType 18 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler 19 | from torch.distributed.fsdp.wrap import ModuleWrapPolicy 20 | from torch.distributed.fsdp._runtime_utils import _reshard 21 | 22 | 23 | def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): 24 | sharding_strategy_dict = { 25 | "NO_SHARD": ShardingStrategy.NO_SHARD, 26 | "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, 27 | "FULL_SHARD": ShardingStrategy.FULL_SHARD, 28 | } 29 | 30 | dtype_dict = { 31 | "fp32": torch.float32, 32 | "fp16": torch.float16, 33 | "bf16": torch.bfloat16, 34 | } 35 | 36 | mixed_precision_config = MixedPrecision( 37 | param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], 38 | reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], 39 | buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], 40 | ) 41 | 42 | sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] 43 | 44 | local_rank = distributed.get_local_rank() 45 | 46 | fsdp_wrapper = partial( 47 | FSDP, 48 | sharding_strategy=sharding_strategy_config, 49 | mixed_precision=mixed_precision_config, 50 | device_id=local_rank, 51 | sync_module_states=True, 52 | use_orig_params=True, 53 | auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), 54 | ) 55 | return fsdp_wrapper 56 | 57 | 58 | def is_fsdp(x): 59 | return isinstance(x, FSDP) 60 | 61 | 62 | def is_sharded_fsdp(x): 63 | return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD 64 | 65 | 66 | def free_if_fsdp(x): 67 | if is_sharded_fsdp(x): 68 | handles = x._handles 69 | true_list = [True for h in handles] 70 | _reshard(x, handles, true_list) 71 | 72 | 73 | def get_fsdp_modules(x): 74 | return FSDP.fsdp_modules(x) 75 | 76 | 77 | def reshard_fsdp_model(x): 78 | for m in get_fsdp_modules(x): 79 | free_if_fsdp(m) 80 | 81 | 82 | def rankstr(): 83 | return f"rank_{distributed.get_global_rank()}" 84 | 85 | 86 | class FSDPCheckpointer(Checkpointer): 87 | def save(self, name: str, **kwargs: Any) -> None: 88 | """ 89 | Dump model and checkpointables to a file. 90 | 91 | Args: 92 | name (str): name of the file. 93 | kwargs (dict): extra arbitrary data to save. 
94 | """ 95 | if not self.save_dir or not self.save_to_disk: 96 | return 97 | 98 | data = {} 99 | with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): 100 | data["model"] = self.model.state_dict() 101 | 102 | # data["model"] = self.model.state_dict() 103 | for key, obj in self.checkpointables.items(): 104 | data[key] = obj.state_dict() 105 | data.update(kwargs) 106 | 107 | basename = f"{name}.{rankstr()}.pth" 108 | save_file = os.path.join(self.save_dir, basename) 109 | assert os.path.basename(save_file) == basename, basename 110 | self.logger.info("Saving checkpoint to {}".format(save_file)) 111 | with self.path_manager.open(save_file, "wb") as f: 112 | torch.save(data, f) 113 | self.tag_last_checkpoint(basename) 114 | 115 | def load(self, *args, **kwargs): 116 | with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): 117 | return super().load(*args, **kwargs) 118 | 119 | def has_checkpoint(self) -> bool: 120 | """ 121 | Returns: 122 | bool: whether a checkpoint exists in the target directory. 123 | """ 124 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") 125 | return self.path_manager.exists(save_file) 126 | 127 | def get_checkpoint_file(self) -> str: 128 | """ 129 | Returns: 130 | str: The latest checkpoint file in target directory. 131 | """ 132 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") 133 | try: 134 | with self.path_manager.open(save_file, "r") as f: 135 | last_saved = f.read().strip() 136 | except IOError: 137 | # if file doesn't exist, maybe because it has just been 138 | # deleted by a separate process 139 | return "" 140 | # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got 141 | # `Union[bytes, str]`. 142 | return os.path.join(self.save_dir, last_saved) 143 | 144 | def tag_last_checkpoint(self, last_filename_basename: str) -> None: 145 | """ 146 | Tag the last checkpoint. 147 | 148 | Args: 149 | last_filename_basename (str): the basename of the last filename. 150 | """ 151 | if distributed.is_enabled(): 152 | torch.distributed.barrier() 153 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") 154 | with self.path_manager.open(save_file, "w") as f: 155 | f.write(last_filename_basename) # pyre-ignore 156 | 157 | 158 | ShardedGradScaler = ShardedGradScaler 159 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | # 2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 
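# Sketch of what the stochastic-depth helpers above compute (hypothetical
# shapes): instead of zeroing the residual branch per sample, a random subset
# of the batch is pushed through residual_func and scattered back with a
# compensating scale of b / sample_subset_size, so the expected update is
# unchanged while the residual computation is cheaper.
#
#   x = torch.randn(4, 3, 2)              # (B, N, D)
#   out = drop_add_residual_stochastic_depth(
#       x, residual_func=lambda t: 0.1 * t, sample_drop_ratio=0.5
#   )                                     # out.shape == (4, 3, 2)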
203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
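#
# Usage sketch for DINOHead (hypothetical dimensions): backbone embeddings are
# mapped to prototype logits through a small MLP, an L2-normalized bottleneck,
# and a weight-normalized last layer whose gain is initialized to 1.
#
#   head = DINOHead(in_dim=384, out_dim=65536)
#   logits = head(torch.randn(8, 384))    # -> shape (8, 65536)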
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 
35 | norm_layer: Normalization layer. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
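#
# Sketch of the SwiGLU feed-forward defined below: w12 projects to twice the
# hidden width, the two halves are combined as silu(x1) * x2, and w3 projects
# back out.  SwiGLUFFNFused shrinks the hidden width to roughly 2/3 (the usual
# compensation for the extra gating matrix) and rounds it up to a multiple of
# 8, e.g. a hypothetical width of 4096 becomes int(4096 * 2 / 3) = 2730,
# rounded up to 2736.
#
#   ffn = SwiGLUFFN(in_features=384, hidden_features=1536)
#   out = ffn(torch.randn(2, 197, 384))   # -> shape (2, 197, 384)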
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/logging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import functools 8 | import logging 9 | import os 10 | import sys 11 | from typing import Optional 12 | 13 | import dinov2.distributed as distributed 14 | from .helpers import MetricLogger, SmoothedValue 15 | 16 | 17 | # So that calling _configure_logger multiple times won't add many handlers 18 | @functools.lru_cache() 19 | def _configure_logger( 20 | name: Optional[str] = None, 21 | *, 22 | level: int = logging.DEBUG, 23 | output: Optional[str] = None, 24 | ): 25 | """ 26 | Configure a logger. 27 | 28 | Adapted from Detectron2. 29 | 30 | Args: 31 | name: The name of the logger to configure. 32 | level: The logging level to use. 33 | output: A file name or a directory to save log. If None, will not save log file. 34 | If ends with ".txt" or ".log", assumed to be a file name. 35 | Otherwise, logs will be saved to `output/log.txt`. 36 | 37 | Returns: 38 | The configured logger. 
39 | """ 40 | 41 | logger = logging.getLogger(name) 42 | logger.setLevel(level) 43 | logger.propagate = False 44 | 45 | # Loosely match Google glog format: 46 | # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg 47 | # but use a shorter timestamp and include the logger name: 48 | # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg 49 | fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " 50 | fmt_message = "%(message)s" 51 | fmt = fmt_prefix + fmt_message 52 | datefmt = "%Y%m%d %H:%M:%S" 53 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) 54 | 55 | # stdout logging for main worker only 56 | if distributed.is_main_process(): 57 | handler = logging.StreamHandler(stream=sys.stdout) 58 | handler.setLevel(logging.DEBUG) 59 | handler.setFormatter(formatter) 60 | logger.addHandler(handler) 61 | 62 | # file logging for all workers 63 | if output: 64 | if os.path.splitext(output)[-1] in (".txt", ".log"): 65 | filename = output 66 | else: 67 | filename = os.path.join(output, "logs", "log.txt") 68 | 69 | if not distributed.is_main_process(): 70 | global_rank = distributed.get_global_rank() 71 | filename = filename + ".rank{}".format(global_rank) 72 | 73 | os.makedirs(os.path.dirname(filename), exist_ok=True) 74 | 75 | handler = logging.StreamHandler(open(filename, "a")) 76 | handler.setLevel(logging.DEBUG) 77 | handler.setFormatter(formatter) 78 | logger.addHandler(handler) 79 | 80 | return logger 81 | 82 | 83 | def setup_logging( 84 | output: Optional[str] = None, 85 | *, 86 | name: Optional[str] = None, 87 | level: int = logging.DEBUG, 88 | capture_warnings: bool = True, 89 | ) -> None: 90 | """ 91 | Setup logging. 92 | 93 | Args: 94 | output: A file name or a directory to save log files. If None, log 95 | files will not be saved. If output ends with ".txt" or ".log", it 96 | is assumed to be a file name. 97 | Otherwise, logs will be saved to `output/log.txt`. 98 | name: The name of the logger to configure, by default the root logger. 99 | level: The logging level to use. 100 | capture_warnings: Whether warnings should be captured as logs. 101 | """ 102 | logging.captureWarnings(capture_warnings) 103 | _configure_logger(name, level=level, output=output) 104 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/logging/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from collections import defaultdict, deque 8 | import datetime 9 | import json 10 | import logging 11 | import time 12 | 13 | import torch 14 | 15 | import dinov2.distributed as distributed 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | class MetricLogger(object): 22 | def __init__(self, delimiter="\t", output_file=None): 23 | self.meters = defaultdict(SmoothedValue) 24 | self.delimiter = delimiter 25 | self.output_file = output_file 26 | 27 | def update(self, **kwargs): 28 | for k, v in kwargs.items(): 29 | if isinstance(v, torch.Tensor): 30 | v = v.item() 31 | assert isinstance(v, (float, int)) 32 | self.meters[k].update(v) 33 | 34 | def __getattr__(self, attr): 35 | if attr in self.meters: 36 | return self.meters[attr] 37 | if attr in self.__dict__: 38 | return self.__dict__[attr] 39 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) 40 | 41 | def __str__(self): 42 | loss_str = [] 43 | for name, meter in self.meters.items(): 44 | loss_str.append("{}: {}".format(name, str(meter))) 45 | return self.delimiter.join(loss_str) 46 | 47 | def synchronize_between_processes(self): 48 | for meter in self.meters.values(): 49 | meter.synchronize_between_processes() 50 | 51 | def add_meter(self, name, meter): 52 | self.meters[name] = meter 53 | 54 | def dump_in_output_file(self, iteration, iter_time, data_time): 55 | if self.output_file is None or not distributed.is_main_process(): 56 | return 57 | dict_to_dump = dict( 58 | iteration=iteration, 59 | iter_time=iter_time, 60 | data_time=data_time, 61 | ) 62 | dict_to_dump.update({k: v.median for k, v in self.meters.items()}) 63 | with open(self.output_file, "a") as f: 64 | f.write(json.dumps(dict_to_dump) + "\n") 65 | pass 66 | 67 | def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0): 68 | i = start_iteration 69 | if not header: 70 | header = "" 71 | start_time = time.time() 72 | end = time.time() 73 | iter_time = SmoothedValue(fmt="{avg:.6f}") 74 | data_time = SmoothedValue(fmt="{avg:.6f}") 75 | 76 | if n_iterations is None: 77 | n_iterations = len(iterable) 78 | 79 | space_fmt = ":" + str(len(str(n_iterations))) + "d" 80 | 81 | log_list = [ 82 | header, 83 | "[{0" + space_fmt + "}/{1}]", 84 | "eta: {eta}", 85 | "{meters}", 86 | "time: {time}", 87 | "data: {data}", 88 | ] 89 | if torch.cuda.is_available(): 90 | log_list += ["max mem: {memory:.0f}"] 91 | 92 | log_msg = self.delimiter.join(log_list) 93 | MB = 1024.0 * 1024.0 94 | for obj in iterable: 95 | data_time.update(time.time() - end) 96 | yield obj 97 | iter_time.update(time.time() - end) 98 | if i % print_freq == 0 or i == n_iterations - 1: 99 | self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) 100 | eta_seconds = iter_time.global_avg * (n_iterations - i) 101 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 102 | if torch.cuda.is_available(): 103 | logger.info( 104 | log_msg.format( 105 | i, 106 | n_iterations, 107 | eta=eta_string, 108 | meters=str(self), 109 | time=str(iter_time), 110 | data=str(data_time), 111 | memory=torch.cuda.max_memory_allocated() / MB, 112 | ) 113 | ) 114 | else: 115 | logger.info( 116 | log_msg.format( 117 | i, 118 | n_iterations, 119 | eta=eta_string, 120 | meters=str(self), 121 | time=str(iter_time), 122 | data=str(data_time), 123 | ) 124 | ) 125 | i += 1 126 | end = time.time() 127 | if i >= n_iterations: 128 | break 129 | total_time = time.time() - start_time 130 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 131 | logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations)) 132 | 133 | 134 | class SmoothedValue: 135 | """Track a series of values and provide access to smoothed values over a 136 | window or the global series average. 137 | """ 138 | 139 | def __init__(self, window_size=20, fmt=None): 140 | if fmt is None: 141 | fmt = "{median:.4f} ({global_avg:.4f})" 142 | self.deque = deque(maxlen=window_size) 143 | self.total = 0.0 144 | self.count = 0 145 | self.fmt = fmt 146 | 147 | def update(self, value, num=1): 148 | self.deque.append(value) 149 | self.count += num 150 | self.total += value * num 151 | 152 | def synchronize_between_processes(self): 153 | """ 154 | Distributed synchronization of the metric 155 | Warning: does not synchronize the deque! 156 | """ 157 | if not distributed.is_enabled(): 158 | return 159 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 160 | torch.distributed.barrier() 161 | torch.distributed.all_reduce(t) 162 | t = t.tolist() 163 | self.count = int(t[0]) 164 | self.total = t[1] 165 | 166 | @property 167 | def median(self): 168 | d = torch.tensor(list(self.deque)) 169 | return d.median().item() 170 | 171 | @property 172 | def avg(self): 173 | d = torch.tensor(list(self.deque), dtype=torch.float32) 174 | return d.mean().item() 175 | 176 | @property 177 | def global_avg(self): 178 | return self.total / self.count 179 | 180 | @property 181 | def max(self): 182 | return max(self.deque) 183 | 184 | @property 185 | def value(self): 186 | return self.deque[-1] 187 | 188 | def __str__(self): 189 | return self.fmt.format( 190 | median=self.median, 191 | avg=self.avg, 192 | global_avg=self.global_avg, 193 | max=self.max, 194 | value=self.value, 195 | ) 196 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_clstoken_loss import DINOLoss 8 | from .ibot_patch_loss import iBOTPatchLoss 9 | from .koleo_loss import KoLeoLoss 10 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/loss/dino_clstoken_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
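#
# Usage sketch for DINOLoss (hypothetical dimensions and temperatures): teacher
# CLS outputs are centered and sharpened, student outputs are log-softmaxed at
# student_temp, and the loss is the cross-entropy summed over all
# teacher/student crop pairs; the center is updated asynchronously from the
# teacher outputs.
#
#   criterion = DINOLoss(out_dim=65536)
#   t = criterion.softmax_center_teacher(teacher_cls, teacher_temp=0.07)
#   loss = criterion([student_cls_a, student_cls_b], [t])
#   criterion.update_center(teacher_cls)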
6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | 13 | class DINOLoss(nn.Module): 14 | def __init__( 15 | self, 16 | out_dim, 17 | student_temp=0.1, 18 | center_momentum=0.9, 19 | ): 20 | super().__init__() 21 | self.student_temp = student_temp 22 | self.center_momentum = center_momentum 23 | self.register_buffer("center", torch.zeros(1, out_dim)) 24 | self.updated = True 25 | self.reduce_handle = None 26 | self.len_teacher_output = None 27 | self.async_batch_center = None 28 | 29 | @torch.no_grad() 30 | def softmax_center_teacher(self, teacher_output, teacher_temp): 31 | self.apply_center_update() 32 | # teacher centering and sharpening 33 | return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) 34 | 35 | @torch.no_grad() 36 | def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): 37 | teacher_output = teacher_output.float() 38 | world_size = dist.get_world_size() if dist.is_initialized() else 1 39 | Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper 40 | B = Q.shape[1] * world_size # number of samples to assign 41 | K = Q.shape[0] # how many prototypes 42 | 43 | # make the matrix sums to 1 44 | sum_Q = torch.sum(Q) 45 | if dist.is_initialized(): 46 | dist.all_reduce(sum_Q) 47 | Q /= sum_Q 48 | 49 | for it in range(n_iterations): 50 | # normalize each row: total weight per prototype must be 1/K 51 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 52 | if dist.is_initialized(): 53 | dist.all_reduce(sum_of_rows) 54 | Q /= sum_of_rows 55 | Q /= K 56 | 57 | # normalize each column: total weight per sample must be 1/B 58 | Q /= torch.sum(Q, dim=0, keepdim=True) 59 | Q /= B 60 | 61 | Q *= B # the columns must sum to 1 so that Q is an assignment 62 | return Q.t() 63 | 64 | def forward(self, student_output_list, teacher_out_softmaxed_centered_list): 65 | """ 66 | Cross-entropy between softmax outputs of the teacher and student networks. 67 | """ 68 | # TODO: Use cross_entropy_distribution here 69 | total_loss = 0 70 | for s in student_output_list: 71 | lsm = F.log_softmax(s / self.student_temp, dim=-1) 72 | for t in teacher_out_softmaxed_centered_list: 73 | loss = torch.sum(t * lsm, dim=-1) 74 | total_loss -= loss.mean() 75 | return total_loss 76 | 77 | @torch.no_grad() 78 | def update_center(self, teacher_output): 79 | self.reduce_center_update(teacher_output) 80 | 81 | @torch.no_grad() 82 | def reduce_center_update(self, teacher_output): 83 | self.updated = False 84 | self.len_teacher_output = len(teacher_output) 85 | self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 86 | if dist.is_initialized(): 87 | self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) 88 | 89 | @torch.no_grad() 90 | def apply_center_update(self): 91 | if self.updated is False: 92 | world_size = dist.get_world_size() if dist.is_initialized() else 1 93 | 94 | if self.reduce_handle is not None: 95 | self.reduce_handle.wait() 96 | _t = self.async_batch_center / (self.len_teacher_output * world_size) 97 | 98 | self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) 99 | 100 | self.updated = True 101 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | import logging 13 | 14 | 15 | logger = logging.getLogger("dinov2") 16 | 17 | 18 | try: 19 | from xformers.ops import cross_entropy 20 | 21 | def lossfunc(t, s, temp): 22 | s = s.float() 23 | t = t.float() 24 | if s.ndim == 2: 25 | return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0) 26 | elif s.ndim == 3: 27 | return -cross_entropy(s, t, temp, bw_inplace=True) 28 | 29 | except ImportError: 30 | 31 | def lossfunc(t, s, temp): 32 | return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) 33 | 34 | 35 | class iBOTPatchLoss(nn.Module): 36 | def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): 37 | super().__init__() 38 | self.student_temp = student_temp 39 | self.center_momentum = center_momentum 40 | self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) 41 | self.updated = True 42 | self.reduce_handle = None 43 | self.len_teacher_patch_tokens = None 44 | self.async_batch_center = None 45 | 46 | @torch.no_grad() 47 | def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp): 48 | self.apply_center_update() 49 | # teacher centering and sharpening 50 | # 51 | # WARNING: 52 | # as self.center is a float32, everything gets casted to float32 afterwards 53 | # 54 | # teacher_patch_tokens = teacher_patch_tokens.float() 55 | # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1) 56 | 57 | return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) 58 | 59 | # this is experimental, keep everything in float16 and let's see what happens: 60 | # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1) 61 | 62 | @torch.no_grad() 63 | def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3): 64 | teacher_output = teacher_output.float() 65 | # world_size = dist.get_world_size() if dist.is_initialized() else 1 66 | Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper 67 | # B = Q.shape[1] * world_size # number of samples to assign 68 | B = n_masked_patches_tensor 69 | dist.all_reduce(B) 70 | K = Q.shape[0] # how many prototypes 71 | 72 | # make the matrix sums to 1 73 | sum_Q = torch.sum(Q) 74 | if dist.is_initialized(): 75 | dist.all_reduce(sum_Q) 76 | Q /= sum_Q 77 | 78 | for it in range(n_iterations): 79 | # normalize each row: total weight per prototype must be 1/K 80 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 81 | if dist.is_initialized(): 82 | dist.all_reduce(sum_of_rows) 83 | Q /= sum_of_rows 84 | Q /= K 85 | 86 | # normalize each column: total weight per sample must be 1/B 87 | Q /= torch.sum(Q, dim=0, keepdim=True) 88 | Q /= B 89 | 90 | Q *= B # the columns must sum to 1 so that Q is an assignment 91 | return Q.t() 92 | 93 | def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): 94 | """ 95 | Cross-entropy between softmax outputs of the teacher and student networks. 
96 | student_patch_tokens: (B, N, D) tensor 97 | teacher_patch_tokens: (B, N, D) tensor 98 | student_masks_flat: (B, N) tensor 99 | """ 100 | t = teacher_patch_tokens 101 | s = student_patch_tokens 102 | loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) 103 | loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) 104 | return -loss.mean() 105 | 106 | def forward_masked( 107 | self, 108 | student_patch_tokens_masked, 109 | teacher_patch_tokens_masked, 110 | student_masks_flat, 111 | n_masked_patches=None, 112 | masks_weight=None, 113 | ): 114 | t = teacher_patch_tokens_masked 115 | s = student_patch_tokens_masked 116 | # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) 117 | loss = lossfunc(t, s, self.student_temp) 118 | if masks_weight is None: 119 | masks_weight = ( 120 | (1 / student_masks_flat.sum(-1).clamp(min=1.0)) 121 | .unsqueeze(-1) 122 | .expand_as(student_masks_flat)[student_masks_flat] 123 | ) 124 | if n_masked_patches is not None: 125 | loss = loss[:n_masked_patches] 126 | loss = loss * masks_weight 127 | return -loss.sum() / student_masks_flat.shape[0] 128 | 129 | @torch.no_grad() 130 | def update_center(self, teacher_patch_tokens): 131 | self.reduce_center_update(teacher_patch_tokens) 132 | 133 | @torch.no_grad() 134 | def reduce_center_update(self, teacher_patch_tokens): 135 | self.updated = False 136 | self.len_teacher_patch_tokens = len(teacher_patch_tokens) 137 | self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) 138 | if dist.is_initialized(): 139 | self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) 140 | 141 | @torch.no_grad() 142 | def apply_center_update(self): 143 | if self.updated is False: 144 | world_size = dist.get_world_size() if dist.is_initialized() else 1 145 | 146 | if self.reduce_handle is not None: 147 | self.reduce_handle.wait() 148 | _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) 149 | 150 | self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) 151 | 152 | self.updated = True 153 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/loss/koleo_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | # import torch.distributed as dist 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class KoLeoLoss(nn.Module): 20 | """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | self.pdist = nn.PairwiseDistance(2, eps=1e-8) 25 | 26 | def pairwise_NNs_inner(self, x): 27 | """ 28 | Pairwise nearest neighbors for L2-normalized vectors. 29 | Uses Torch rather than Faiss to remain on GPU. 
30 | """ 31 | # parwise dot products (= inverse distance) 32 | dots = torch.mm(x, x.t()) 33 | n = x.shape[0] 34 | dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 35 | # max inner prod -> min distance 36 | _, I = torch.max(dots, dim=1) # noqa: E741 37 | return I 38 | 39 | def forward(self, student_output, eps=1e-8): 40 | """ 41 | Args: 42 | student_output (BxD): backbone output of student 43 | """ 44 | with torch.cuda.amp.autocast(enabled=False): 45 | student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) 46 | I = self.pairwise_NNs_inner(student_output) # noqa: E741 47 | distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B 48 | loss = -torch.log(distances + eps).mean() 49 | return loss 50 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | from . import vision_transformer as vits 10 | 11 | 12 | logger = logging.getLogger("dinov2") 13 | 14 | 15 | def build_model(args, only_teacher=False, img_size=224): 16 | args.arch = args.arch.removesuffix("_memeff") 17 | if "vit" in args.arch: 18 | vit_kwargs = dict( 19 | img_size=img_size, 20 | patch_size=args.patch_size, 21 | init_values=args.layerscale, 22 | ffn_layer=args.ffn_layer, 23 | block_chunks=args.block_chunks, 24 | qkv_bias=args.qkv_bias, 25 | proj_bias=args.proj_bias, 26 | ffn_bias=args.ffn_bias, 27 | ) 28 | teacher = vits.__dict__[args.arch](**vit_kwargs) 29 | if only_teacher: 30 | return teacher, teacher.embed_dim 31 | student = vits.__dict__[args.arch]( 32 | **vit_kwargs, 33 | drop_path_rate=args.drop_path_rate, 34 | drop_path_uniform=args.drop_path_uniform, 35 | ) 36 | embed_dim = student.embed_dim 37 | return student, teacher, embed_dim 38 | 39 | 40 | def build_model_from_cfg(cfg, only_teacher=False): 41 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 42 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/run/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/run/eval/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.knn import get_args_parser as get_knn_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.knn import main as knn_main 25 | 26 | self._setup_args() 27 | knn_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 k-NN evaluation" 47 | knn_args_parser = get_knn_args_parser(add_help=False) 48 | parents = [knn_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 55 | submit_jobs(Evaluator, args, name="dinov2:knn") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/run/eval/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.linear import get_args_parser as get_linear_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.linear import main as linear_main 25 | 26 | self._setup_args() 27 | linear_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 linear evaluation" 47 | linear_args_parser = get_linear_args_parser(add_help=False) 48 | parents = [linear_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 
55 | submit_jobs(Evaluator, args, name="dinov2:linear") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/run/eval/log_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.log_regression import main as log_regression_main 25 | 26 | self._setup_args() 27 | log_regression_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 logistic evaluation" 47 | log_regression_args_parser = get_log_regression_args_parser(add_help=False) 48 | parents = [log_regression_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 55 | submit_jobs(Evaluator, args, name="dinov2:logreg") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/run/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
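#
# Sketch of how the launchers under dinov2/run/ use this module: each one
# extends this parser with its task-specific arguments and hands its task
# class to submit_jobs, which builds the task and schedules it through
# submitit on SLURM (task_parser below stands for that task-specific parser):
#
#   args_parser = get_args_parser(description="...", parents=[task_parser])
#   submit_jobs(Trainer, args_parser.parse_args(), name="dinov2:train")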
6 | 7 | import argparse 8 | import logging 9 | import os 10 | from pathlib import Path 11 | from typing import List, Optional 12 | 13 | import submitit 14 | 15 | from dinov2.utils.cluster import ( 16 | get_slurm_executor_parameters, 17 | get_slurm_partition, 18 | get_user_checkpoint_path, 19 | ) 20 | 21 | 22 | logger = logging.getLogger("dinov2") 23 | 24 | 25 | def get_args_parser( 26 | description: Optional[str] = None, 27 | parents: Optional[List[argparse.ArgumentParser]] = None, 28 | add_help: bool = True, 29 | ) -> argparse.ArgumentParser: 30 | parents = parents or [] 31 | slurm_partition = get_slurm_partition() 32 | parser = argparse.ArgumentParser( 33 | description=description, 34 | parents=parents, 35 | add_help=add_help, 36 | ) 37 | parser.add_argument( 38 | "--ngpus", 39 | "--gpus", 40 | "--gpus-per-node", 41 | default=8, 42 | type=int, 43 | help="Number of GPUs to request on each node", 44 | ) 45 | parser.add_argument( 46 | "--nodes", 47 | "--nnodes", 48 | default=2, 49 | type=int, 50 | help="Number of nodes to request", 51 | ) 52 | parser.add_argument( 53 | "--timeout", 54 | default=2800, 55 | type=int, 56 | help="Duration of the job", 57 | ) 58 | parser.add_argument( 59 | "--partition", 60 | default=slurm_partition, 61 | type=str, 62 | help="Partition where to submit", 63 | ) 64 | parser.add_argument( 65 | "--use-volta32", 66 | action="store_true", 67 | help="Request V100-32GB GPUs", 68 | ) 69 | parser.add_argument( 70 | "--comment", 71 | default="", 72 | type=str, 73 | help="Comment to pass to scheduler, e.g. priority message", 74 | ) 75 | parser.add_argument( 76 | "--exclude", 77 | default="", 78 | type=str, 79 | help="Nodes to exclude", 80 | ) 81 | return parser 82 | 83 | 84 | def get_shared_folder() -> Path: 85 | user_checkpoint_path = get_user_checkpoint_path() 86 | if user_checkpoint_path is None: 87 | raise RuntimeError("Path to user checkpoint cannot be determined") 88 | path = user_checkpoint_path / "experiments" 89 | path.mkdir(exist_ok=True) 90 | return path 91 | 92 | 93 | def submit_jobs(task_class, args, name: str): 94 | if not args.output_dir: 95 | args.output_dir = str(get_shared_folder() / "%j") 96 | 97 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 98 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 99 | 100 | kwargs = {} 101 | if args.use_volta32: 102 | kwargs["slurm_constraint"] = "volta32gb" 103 | if args.comment: 104 | kwargs["slurm_comment"] = args.comment 105 | if args.exclude: 106 | kwargs["slurm_exclude"] = args.exclude 107 | 108 | executor_params = get_slurm_executor_parameters( 109 | nodes=args.nodes, 110 | num_gpus_per_node=args.ngpus, 111 | timeout_min=args.timeout, # max is 60 * 72 112 | slurm_signal_delay_s=120, 113 | slurm_partition=args.partition, 114 | **kwargs, 115 | ) 116 | executor.update_parameters(name=name, **executor_params) 117 | 118 | task = task_class(args) 119 | job = executor.submit(task) 120 | 121 | logger.info(f"Submitted job_id: {job.job_id}") 122 | str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) 123 | logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") 124 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/run/train/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.logging import setup_logging 12 | from dinov2.train import get_args_parser as get_train_args_parser 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Trainer(object): 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.train import main as train_main 25 | 26 | self._setup_args() 27 | train_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 training" 47 | train_args_parser = get_train_args_parser(add_help=False) 48 | parents = [train_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 55 | submit_jobs(Trainer, args, name="dinov2:train") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/train/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .train import get_args_parser, main 8 | from .ssl_meta_arch import SSLMetaArch 9 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/train/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
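#
# Usage sketch (hypothetical paths and GPU count; the exact dataset string
# depends on the dataset registry in dinov2.data): a typical single-node
# launch of the trainer defined below, overriding config keys through the
# trailing "opts" arguments.
#
#   torchrun --nproc_per_node=8 dinov2/train/train.py \
#       --config-file dinov2/configs/train/vitl16_short.yaml \
#       --output-dir ./output \
#       train.dataset_path=<dataset spec>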
6 | 7 | import argparse 8 | import logging 9 | import math 10 | import os 11 | from functools import partial 12 | 13 | from fvcore.common.checkpoint import PeriodicCheckpointer 14 | import torch 15 | 16 | from dinov2.data import SamplerType, make_data_loader, make_dataset 17 | from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator 18 | import dinov2.distributed as distributed 19 | from dinov2.fsdp import FSDPCheckpointer 20 | from dinov2.logging import MetricLogger 21 | from dinov2.utils.config import setup 22 | from dinov2.utils.utils import CosineScheduler 23 | 24 | from dinov2.train.ssl_meta_arch import SSLMetaArch 25 | 26 | 27 | torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default 28 | logger = logging.getLogger("dinov2") 29 | 30 | 31 | def get_args_parser(add_help: bool = True): 32 | parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) 33 | parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") 34 | parser.add_argument( 35 | "--no-resume", 36 | action="store_true", 37 | help="Whether to not attempt to resume from the checkpoint directory. ", 38 | ) 39 | parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") 40 | parser.add_argument("--eval", type=str, default="", help="Eval type to perform") 41 | parser.add_argument( 42 | "opts", 43 | help=""" 44 | Modify config options at the end of the command. For Yacs configs, use 45 | space-separated "PATH.KEY VALUE" pairs. 46 | For python-based LazyConfig, use "path.key=value". 47 | """.strip(), 48 | default=None, 49 | nargs=argparse.REMAINDER, 50 | ) 51 | parser.add_argument( 52 | "--output-dir", 53 | "--output_dir", 54 | default="", 55 | type=str, 56 | help="Output directory to save logs and checkpoints", 57 | ) 58 | 59 | return parser 60 | 61 | 62 | def build_optimizer(cfg, params_groups): 63 | return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) 64 | 65 | 66 | def build_schedulers(cfg): 67 | OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH 68 | lr = dict( 69 | base_value=cfg.optim["lr"], 70 | final_value=cfg.optim["min_lr"], 71 | total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, 72 | warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, 73 | start_warmup_value=0, 74 | ) 75 | wd = dict( 76 | base_value=cfg.optim["weight_decay"], 77 | final_value=cfg.optim["weight_decay_end"], 78 | total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, 79 | ) 80 | momentum = dict( 81 | base_value=cfg.teacher["momentum_teacher"], 82 | final_value=cfg.teacher["final_momentum_teacher"], 83 | total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, 84 | ) 85 | teacher_temp = dict( 86 | base_value=cfg.teacher["teacher_temp"], 87 | final_value=cfg.teacher["teacher_temp"], 88 | total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, 89 | warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, 90 | start_warmup_value=cfg.teacher["warmup_teacher_temp"], 91 | ) 92 | 93 | lr_schedule = CosineScheduler(**lr) 94 | wd_schedule = CosineScheduler(**wd) 95 | momentum_schedule = CosineScheduler(**momentum) 96 | teacher_temp_schedule = CosineScheduler(**teacher_temp) 97 | last_layer_lr_schedule = CosineScheduler(**lr) 98 | 99 | last_layer_lr_schedule.schedule[ 100 | : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH 101 | ] = 0 # mimicking the original schedules 102 | 103 | logger.info("Schedulers 
ready.") 104 | 105 | return ( 106 | lr_schedule, 107 | wd_schedule, 108 | momentum_schedule, 109 | teacher_temp_schedule, 110 | last_layer_lr_schedule, 111 | ) 112 | 113 | 114 | def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): 115 | for param_group in optimizer.param_groups: 116 | is_last_layer = param_group["is_last_layer"] 117 | lr_multiplier = param_group["lr_multiplier"] 118 | wd_multiplier = param_group["wd_multiplier"] 119 | param_group["weight_decay"] = wd * wd_multiplier 120 | param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier 121 | 122 | 123 | def do_test(cfg, model, iteration): 124 | new_state_dict = model.teacher.state_dict() 125 | 126 | if distributed.is_main_process(): 127 | iterstring = str(iteration) 128 | eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) 129 | os.makedirs(eval_dir, exist_ok=True) 130 | # save teacher checkpoint 131 | teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") 132 | torch.save({"teacher": new_state_dict}, teacher_ckp_path) 133 | 134 | 135 | def do_train(cfg, model, resume=False): 136 | model.train() 137 | inputs_dtype = torch.half 138 | fp16_scaler = model.fp16_scaler # for mixed precision training 139 | 140 | # setup optimizer 141 | 142 | optimizer = build_optimizer(cfg, model.get_params_groups()) 143 | ( 144 | lr_schedule, 145 | wd_schedule, 146 | momentum_schedule, 147 | teacher_temp_schedule, 148 | last_layer_lr_schedule, 149 | ) = build_schedulers(cfg) 150 | 151 | # checkpointer 152 | checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) 153 | 154 | start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 155 | 156 | OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH 157 | max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH 158 | 159 | periodic_checkpointer = PeriodicCheckpointer( 160 | checkpointer, 161 | period=3 * OFFICIAL_EPOCH_LENGTH, 162 | max_iter=max_iter, 163 | max_to_keep=3, 164 | ) 165 | 166 | # setup data preprocessing 167 | 168 | img_size = cfg.crops.global_crops_size 169 | patch_size = cfg.student.patch_size 170 | n_tokens = (img_size // patch_size) ** 2 171 | mask_generator = MaskingGenerator( 172 | input_size=(img_size // patch_size, img_size // patch_size), 173 | max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, 174 | ) 175 | 176 | data_transform = DataAugmentationDINO( 177 | cfg.crops.global_crops_scale, 178 | cfg.crops.local_crops_scale, 179 | cfg.crops.local_crops_number, 180 | global_crops_size=cfg.crops.global_crops_size, 181 | local_crops_size=cfg.crops.local_crops_size, 182 | ) 183 | 184 | collate_fn = partial( 185 | collate_data_and_cast, 186 | mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, 187 | mask_probability=cfg.ibot.mask_sample_probability, 188 | n_tokens=n_tokens, 189 | mask_generator=mask_generator, 190 | dtype=inputs_dtype, 191 | ) 192 | 193 | # setup data loader 194 | 195 | dataset = make_dataset( 196 | dataset_str=cfg.train.dataset_path, 197 | transform=data_transform, 198 | target_transform=lambda _: (), 199 | ) 200 | # sampler_type = SamplerType.INFINITE 201 | sampler_type = SamplerType.SHARDED_INFINITE 202 | data_loader = make_data_loader( 203 | dataset=dataset, 204 | batch_size=cfg.train.batch_size_per_gpu, 205 | num_workers=cfg.train.num_workers, 206 | shuffle=True, 207 | seed=start_iter, # TODO: Fix this -- cfg.train.seed 208 | sampler_type=sampler_type, 209 | sampler_advance=0, # TODO(qas): fix this -- start_iter * 
cfg.train.batch_size_per_gpu, 210 | drop_last=True, 211 | collate_fn=collate_fn, 212 | ) 213 | 214 | # training loop 215 | 216 | iteration = start_iter 217 | 218 | logger.info("Starting training from iteration {}".format(start_iter)) 219 | metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") 220 | metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) 221 | header = "Training" 222 | 223 | for data in metric_logger.log_every( 224 | data_loader, 225 | 10, 226 | header, 227 | max_iter, 228 | start_iter, 229 | ): 230 | current_batch_size = data["collated_global_crops"].shape[0] / 2 231 | if iteration > max_iter: 232 | return 233 | 234 | # apply schedules 235 | 236 | lr = lr_schedule[iteration] 237 | wd = wd_schedule[iteration] 238 | mom = momentum_schedule[iteration] 239 | teacher_temp = teacher_temp_schedule[iteration] 240 | last_layer_lr = last_layer_lr_schedule[iteration] 241 | apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) 242 | 243 | # compute losses 244 | 245 | optimizer.zero_grad(set_to_none=True) 246 | loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) 247 | 248 | # clip gradients 249 | 250 | if fp16_scaler is not None: 251 | if cfg.optim.clip_grad: 252 | fp16_scaler.unscale_(optimizer) 253 | for v in model.student.values(): 254 | v.clip_grad_norm_(cfg.optim.clip_grad) 255 | fp16_scaler.step(optimizer) 256 | fp16_scaler.update() 257 | else: 258 | if cfg.optim.clip_grad: 259 | for v in model.student.values(): 260 | v.clip_grad_norm_(cfg.optim.clip_grad) 261 | optimizer.step() 262 | 263 | # perform teacher EMA update 264 | 265 | model.update_teacher(mom) 266 | 267 | # logging 268 | 269 | if distributed.get_global_size() > 1: 270 | for v in loss_dict.values(): 271 | torch.distributed.all_reduce(v) 272 | loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} 273 | 274 | if math.isnan(sum(loss_dict_reduced.values())): 275 | logger.info("NaN detected") 276 | raise AssertionError 277 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 278 | 279 | metric_logger.update(lr=lr) 280 | metric_logger.update(wd=wd) 281 | metric_logger.update(mom=mom) 282 | metric_logger.update(last_layer_lr=last_layer_lr) 283 | metric_logger.update(current_batch_size=current_batch_size) 284 | metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) 285 | 286 | # checkpointing and testing 287 | 288 | if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: 289 | do_test(cfg, model, f"training_{iteration}") 290 | torch.cuda.synchronize() 291 | periodic_checkpointer.step(iteration) 292 | 293 | iteration = iteration + 1 294 | metric_logger.synchronize_between_processes() 295 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 296 | 297 | 298 | def main(args): 299 | cfg = setup(args) 300 | 301 | model = SSLMetaArch(cfg).to(torch.device("cuda")) 302 | model.prepare_for_distributed_training() 303 | 304 | logger.info("Model:\n{}".format(model)) 305 | if args.eval_only: 306 | iteration = ( 307 | FSDPCheckpointer(model, save_dir=cfg.train.output_dir) 308 | .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) 309 | .get("iteration", -1) 310 | + 1 311 | ) 312 | return do_test(cfg, model, f"manual_{iteration}") 313 | 314 | do_train(cfg, model, resume=not args.no_resume) 315 | 316 | 317 | if __name__ == "__main__": 318 | args = get_args_parser(add_help=True).parse_args() 319 | main(args) 320 | 
-------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/utils/cluster.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from enum import Enum 8 | import os 9 | from pathlib import Path 10 | from typing import Any, Dict, Optional 11 | 12 | 13 | class ClusterType(Enum): 14 | AWS = "aws" 15 | FAIR = "fair" 16 | RSC = "rsc" 17 | 18 | 19 | def _guess_cluster_type() -> ClusterType: 20 | uname = os.uname() 21 | if uname.sysname == "Linux": 22 | if uname.release.endswith("-aws"): 23 | # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" 24 | return ClusterType.AWS 25 | elif uname.nodename.startswith("rsc"): 26 | # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" 27 | return ClusterType.RSC 28 | 29 | return ClusterType.FAIR 30 | 31 | 32 | def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: 33 | if cluster_type is None: 34 | return _guess_cluster_type() 35 | 36 | return cluster_type 37 | 38 | 39 | def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 40 | cluster_type = get_cluster_type(cluster_type) 41 | if cluster_type is None: 42 | return None 43 | 44 | CHECKPOINT_DIRNAMES = { 45 | ClusterType.AWS: "checkpoints", 46 | ClusterType.FAIR: "checkpoint", 47 | ClusterType.RSC: "checkpoint/dino", 48 | } 49 | return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] 50 | 51 | 52 | def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 53 | checkpoint_path = get_checkpoint_path(cluster_type) 54 | if checkpoint_path is None: 55 | return None 56 | 57 | username = os.environ.get("USER") 58 | assert username is not None 59 | return checkpoint_path / username 60 | 61 | 62 | def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: 63 | cluster_type = get_cluster_type(cluster_type) 64 | if cluster_type is None: 65 | return None 66 | 67 | SLURM_PARTITIONS = { 68 | ClusterType.AWS: "learnlab", 69 | ClusterType.FAIR: "learnlab", 70 | ClusterType.RSC: "learn", 71 | } 72 | return SLURM_PARTITIONS[cluster_type] 73 | 74 | 75 | def get_slurm_executor_parameters( 76 | nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs 77 | ) -> Dict[str, Any]: 78 | # create default parameters 79 | params = { 80 | "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html 81 | "gpus_per_node": num_gpus_per_node, 82 | "tasks_per_node": num_gpus_per_node, # one task per GPU 83 | "cpus_per_task": 10, 84 | "nodes": nodes, 85 | "slurm_partition": get_slurm_partition(cluster_type), 86 | } 87 | # apply cluster-specific adjustments 88 | cluster_type = get_cluster_type(cluster_type) 89 | if cluster_type == ClusterType.AWS: 90 
| params["cpus_per_task"] = 12 91 | del params["mem_gb"] 92 | elif cluster_type == ClusterType.RSC: 93 | params["cpus_per_task"] = 12 94 | # set additional parameters / apply overrides 95 | params.update(kwargs) 96 | return params 97 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import logging 9 | import os 10 | 11 | from omegaconf import OmegaConf 12 | 13 | import dinov2.distributed as distributed 14 | from dinov2.logging import setup_logging 15 | from dinov2.utils import utils 16 | from dinov2.configs import dinov2_default_config 17 | 18 | 19 | logger = logging.getLogger("dinov2") 20 | 21 | 22 | def apply_scaling_rules_to_cfg(cfg): # to fix 23 | if cfg.optim.scaling_rule == "sqrt_wrt_1024": 24 | base_lr = cfg.optim.base_lr 25 | cfg.optim.lr = base_lr 26 | cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) 27 | logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") 28 | else: 29 | raise NotImplementedError 30 | return cfg 31 | 32 | 33 | def write_config(cfg, output_dir, name="config.yaml"): 34 | logger.info(OmegaConf.to_yaml(cfg)) 35 | saved_cfg_path = os.path.join(output_dir, name) 36 | with open(saved_cfg_path, "w") as f: 37 | OmegaConf.save(config=cfg, f=f) 38 | return saved_cfg_path 39 | 40 | 41 | def get_cfg_from_args(args): 42 | args.output_dir = os.path.abspath(args.output_dir) 43 | args.opts += [f"train.output_dir={args.output_dir}"] 44 | default_cfg = OmegaConf.create(dinov2_default_config) 45 | cfg = OmegaConf.load(args.config_file) 46 | cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) 47 | return cfg 48 | 49 | 50 | def default_setup(args): 51 | distributed.enable(overwrite=True) 52 | seed = getattr(args, "seed", 0) 53 | rank = distributed.get_global_rank() 54 | 55 | global logger 56 | setup_logging(output=args.output_dir, level=logging.INFO) 57 | logger = logging.getLogger("dinov2") 58 | 59 | utils.fix_random_seeds(seed + rank) 60 | logger.info("git:\n {}\n".format(utils.get_sha())) 61 | logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 62 | 63 | 64 | def setup(args): 65 | """ 66 | Create configs and perform basic setups. 67 | """ 68 | cfg = get_cfg_from_args(args) 69 | os.makedirs(args.output_dir, exist_ok=True) 70 | default_setup(args) 71 | apply_scaling_rules_to_cfg(cfg) 72 | write_config(cfg, args.output_dir) 73 | return cfg 74 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from typing import Dict, Union 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | TypeSpec = Union[str, np.dtype, torch.dtype] 15 | 16 | 17 | _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { 18 | np.dtype("bool"): torch.bool, 19 | np.dtype("uint8"): torch.uint8, 20 | np.dtype("int8"): torch.int8, 21 | np.dtype("int16"): torch.int16, 22 | np.dtype("int32"): torch.int32, 23 | np.dtype("int64"): torch.int64, 24 | np.dtype("float16"): torch.float16, 25 | np.dtype("float32"): torch.float32, 26 | np.dtype("float64"): torch.float64, 27 | np.dtype("complex64"): torch.complex64, 28 | np.dtype("complex128"): torch.complex128, 29 | } 30 | 31 | 32 | def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: 33 | if isinstance(dtype, torch.dtype): 34 | return dtype 35 | if isinstance(dtype, str): 36 | dtype = np.dtype(dtype) 37 | assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}" 38 | return _NUMPY_TO_TORCH_DTYPE[dtype] 39 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | import logging 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): 15 | """ 16 | Calculate lr decay rate for different ViT blocks. 17 | Args: 18 | name (string): parameter name. 19 | lr_decay_rate (float): base lr decay rate. 20 | num_layers (int): number of ViT blocks. 21 | Returns: 22 | lr decay rate for the given parameter. 23 | """ 24 | layer_id = num_layers + 1 25 | if name.startswith("backbone") or force_is_backbone: 26 | if ".pos_embed" in name or ".patch_embed" in name or ".mask_token" in name or ".cls_token" in name: 27 | layer_id = 0 28 | elif force_is_backbone and ( 29 | "pos_embed" in name or "patch_embed" in name or "mask_token" in name or "cls_token" in name 30 | ): 31 | layer_id = 0 32 | elif ".blocks." in name and ".residual." not in name: 33 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 34 | elif chunked_blocks and "blocks." in name and "residual." not in name: 35 | layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 36 | elif "blocks." in name and "residual." 
not in name: 37 | layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 38 | 39 | return lr_decay_rate ** (num_layers + 1 - layer_id) 40 | 41 | 42 | def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): 43 | chunked_blocks = False 44 | if hasattr(model, "n_blocks"): 45 | logger.info("chunked fsdp") 46 | n_blocks = model.n_blocks 47 | chunked_blocks = model.chunked_blocks 48 | elif hasattr(model, "blocks"): 49 | logger.info("first code branch") 50 | n_blocks = len(model.blocks) 51 | elif hasattr(model, "backbone"): 52 | logger.info("second code branch") 53 | n_blocks = len(model.backbone.blocks) 54 | else: 55 | logger.info("else code branch") 56 | n_blocks = 0 57 | all_param_groups = [] 58 | 59 | for name, param in model.named_parameters(): 60 | name = name.replace("_fsdp_wrapped_module.", "") 61 | if not param.requires_grad: 62 | continue 63 | decay_rate = get_vit_lr_decay_rate( 64 | name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks 65 | ) 66 | d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} 67 | 68 | if "last_layer" in name: 69 | d.update({"is_last_layer": True}) 70 | 71 | if name.endswith(".bias") or "norm" in name or "gamma" in name: 72 | d.update({"wd_multiplier": 0.0}) 73 | 74 | if "patch_embed" in name: 75 | d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) 76 | 77 | all_param_groups.append(d) 78 | logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") 79 | 80 | return all_param_groups 81 | 82 | 83 | def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): 84 | fused_params_groups = defaultdict(lambda: {"params": []}) 85 | for d in all_params_groups: 86 | identifier = "" 87 | for k in keys: 88 | identifier += k + str(d[k]) + "_" 89 | 90 | for k in keys: 91 | fused_params_groups[identifier][k] = d[k] 92 | fused_params_groups[identifier]["params"].append(d["params"]) 93 | 94 | return fused_params_groups.values() 95 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/dinov2/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import os 9 | import random 10 | import subprocess 11 | from urllib.parse import urlparse 12 | 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 22 | if urlparse(pretrained_weights).scheme: # If it looks like an URL 23 | state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") 24 | else: 25 | state_dict = torch.load(pretrained_weights, map_location="cpu") 26 | if checkpoint_key is not None and checkpoint_key in state_dict: 27 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 28 | state_dict = state_dict[checkpoint_key] 29 | # remove `module.` prefix 30 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 31 | # remove `backbone.` prefix induced by multicrop wrapper 32 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 33 | msg = model.load_state_dict(state_dict, strict=False) 34 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 35 | 36 | 37 | def fix_random_seeds(seed=31): 38 | """ 39 | Fix random seeds. 40 | """ 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 46 | 47 | def get_sha(): 48 | cwd = os.path.dirname(os.path.abspath(__file__)) 49 | 50 | def _run(command): 51 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 52 | 53 | sha = "N/A" 54 | diff = "clean" 55 | branch = "N/A" 56 | try: 57 | sha = _run(["git", "rev-parse", "HEAD"]) 58 | subprocess.check_output(["git", "diff"], cwd=cwd) 59 | diff = _run(["git", "diff-index", "HEAD"]) 60 | diff = "has uncommitted changes" if diff else "clean" 61 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 62 | except Exception: 63 | pass 64 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 65 | return message 66 | 67 | 68 | class CosineScheduler(object): 69 | def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): 70 | super().__init__() 71 | self.final_value = final_value 72 | self.total_iters = total_iters 73 | 74 | freeze_schedule = np.zeros((freeze_iters)) 75 | 76 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 77 | 78 | iters = np.arange(total_iters - warmup_iters - freeze_iters) 79 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 80 | self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) 81 | 82 | assert len(self.schedule) == self.total_iters 83 | 84 | def __getitem__(self, it): 85 | if it >= self.total_iters: 86 | return self.final_value 87 | else: 88 | return self.schedule[it] 89 | 90 | 91 | def has_batchnorms(model): 92 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 93 | for name, module in model.named_modules(): 94 | if isinstance(module, bn_types): 95 | return True 96 | return False 97 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from enum import Enum 7 | from typing import Union 8 | 9 | import torch 10 | 11 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 12 | 13 | 14 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 15 | compact_arch_name = arch_name.replace("_", "")[:4] 16 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 17 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 18 | 19 | 20 | class Weights(Enum): 21 | LVD142M = "LVD142M" 22 | 23 | 24 | def _make_dinov2_model( 25 | *, 26 | arch_name: str = "vit_large", 27 | img_size: int = 518, 28 | patch_size: int = 14, 29 | init_values: float = 1.0, 30 | ffn_layer: str = "mlp", 31 | block_chunks: int = 0, 32 | num_register_tokens: int = 0, 33 | interpolate_antialias: bool = False, 34 | interpolate_offset: float = 0.1, 35 | pretrained: bool = True, 36 | weights: Union[Weights, str] = Weights.LVD142M, 37 | **kwargs, 38 | ): 39 | import vision_transformer as vits 40 | 41 | if isinstance(weights, str): 42 | try: 43 | weights = Weights[weights] 44 | except KeyError: 45 | raise AssertionError(f"Unsupported weights: {weights}") 46 | 47 | model_base_name = _make_dinov2_model_name(arch_name, patch_size) 48 | vit_kwargs = dict( 49 | img_size=img_size, 50 | patch_size=patch_size, 51 | init_values=init_values, 52 | ffn_layer=ffn_layer, 53 | block_chunks=block_chunks, 54 | num_register_tokens=num_register_tokens, 55 | interpolate_antialias=interpolate_antialias, 56 | interpolate_offset=interpolate_offset, 57 | ) 58 | vit_kwargs.update(**kwargs) 59 | model = vits.__dict__[arch_name](**vit_kwargs) 60 | 61 | if pretrained: 62 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) 63 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" 64 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 65 | model.load_state_dict(state_dict, strict=True) 66 | 67 | return model 68 | 69 | 70 | def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 71 | """ 72 | DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. 73 | """ 74 | return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) 75 | 76 | 77 | def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 78 | """ 79 | DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 80 | """ 81 | return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) 82 | 83 | 84 | def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 85 | """ 86 | DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 87 | """ 88 | return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) 89 | 90 | 91 | def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 92 | """ 93 | DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. 
94 | """ 95 | return _make_dinov2_model( 96 | arch_name="vit_giant2", 97 | ffn_layer="swiglufused", 98 | weights=weights, 99 | pretrained=pretrained, 100 | **kwargs, 101 | ) 102 | 103 | 104 | def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 105 | """ 106 | DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 107 | """ 108 | return _make_dinov2_model( 109 | arch_name="vit_small", 110 | pretrained=pretrained, 111 | weights=weights, 112 | num_register_tokens=4, 113 | interpolate_antialias=True, 114 | interpolate_offset=0.0, 115 | **kwargs, 116 | ) 117 | 118 | 119 | def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 120 | """ 121 | DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 122 | """ 123 | return _make_dinov2_model( 124 | arch_name="vit_base", 125 | pretrained=pretrained, 126 | weights=weights, 127 | num_register_tokens=4, 128 | interpolate_antialias=True, 129 | interpolate_offset=0.0, 130 | **kwargs, 131 | ) 132 | 133 | 134 | def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 135 | """ 136 | DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. 137 | """ 138 | return _make_dinov2_model( 139 | arch_name="vit_large", 140 | pretrained=pretrained, 141 | weights=weights, 142 | num_register_tokens=4, 143 | interpolate_antialias=True, 144 | interpolate_offset=0.0, 145 | **kwargs, 146 | ) 147 | 148 | 149 | def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 150 | """ 151 | DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
152 | """ 153 | return _make_dinov2_model( 154 | arch_name="vit_giant2", 155 | ffn_layer="swiglufused", 156 | weights=weights, 157 | pretrained=pretrained, 158 | num_register_tokens=4, 159 | interpolate_antialias=True, 160 | interpolate_offset=0.0, 161 | **kwargs, 162 | ) 163 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | 4 | [tool.pylint.master] 5 | persistent = false 6 | score = false 7 | 8 | [tool.pylint.messages_control] 9 | disable = "all" 10 | enable = [ 11 | "miscellaneous", 12 | "similarities", 13 | ] 14 | 15 | [tool.pylint.similarities] 16 | ignore-comments = true 17 | ignore-docstrings = true 18 | ignore-imports = true 19 | min-similarity-lines = 8 20 | 21 | [tool.pylint.reports] 22 | reports = false 23 | 24 | [tool.pylint.miscellaneous] 25 | notes = [ 26 | "FIXME", 27 | "XXX", 28 | "TODO", 29 | ] 30 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ -n "$1" ]; then 4 | echo "linting \"$1\"" 5 | fi 6 | 7 | echo "running black" 8 | if [ -n "$1" ]; then 9 | black "$1" 10 | else 11 | black dinov2 12 | fi 13 | 14 | echo "running flake8" 15 | if [ -n "$1" ]; then 16 | flake8 "$1" 17 | else 18 | flake8 19 | fi 20 | 21 | echo "running pylint" 22 | if [ -n "$1" ]; then 23 | pylint "$1" 24 | else 25 | pylint dinov2 26 | fi 27 | 28 | exit 0 29 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203,E501,W503 4 | per-file-ignores = 5 | __init__.py:F401 6 | exclude = 7 | venv 8 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | import re 9 | from typing import List, Tuple 10 | 11 | from setuptools import setup, find_packages 12 | 13 | 14 | NAME = "dinov2" 15 | DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method." 
16 | 17 | URL = "https://github.com/facebookresearch/dinov2" 18 | AUTHOR = "FAIR" 19 | REQUIRES_PYTHON = ">=3.9.0" 20 | HERE = Path(__file__).parent 21 | 22 | 23 | try: 24 | with open(HERE / "README.md", encoding="utf-8") as f: 25 | long_description = "\n" + f.read() 26 | except FileNotFoundError: 27 | long_description = DESCRIPTION 28 | 29 | 30 | def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]: 31 | requirements = [] 32 | extra_indices = [] 33 | with open(path) as f: 34 | for line in f.readlines(): 35 | line = line.rstrip("\r\n") 36 | if line.startswith("--extra-index-url "): 37 | extra_indices.append(line[18:]) 38 | continue 39 | requirements.append(line) 40 | return requirements, extra_indices 41 | 42 | 43 | def get_package_version() -> str: 44 | with open(HERE / "dinov2/__init__.py") as f: 45 | result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 46 | if result: 47 | return result.group(1) 48 | raise RuntimeError("Can't get package version") 49 | 50 | 51 | requirements, extra_indices = get_requirements() 52 | version = get_package_version() 53 | dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt") 54 | 55 | 56 | setup( 57 | name=NAME, 58 | version=version, 59 | description=DESCRIPTION, 60 | long_description=long_description, 61 | long_description_content_type="text/markdown", 62 | author=AUTHOR, 63 | python_requires=REQUIRES_PYTHON, 64 | url=URL, 65 | packages=find_packages(), 66 | package_data={ 67 | "": ["*.yaml"], 68 | }, 69 | install_requires=requirements, 70 | dependency_links=extra_indices, 71 | extras_require={ 72 | "dev": dev_requirements, 73 | }, 74 | include_package_data=True, 75 | license="CC-BY-NC", 76 | license_files=("LICENSE",), 77 | classifiers=[ 78 | # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py 79 | "Development Status :: 3 - Alpha", 80 | "Intended Audience :: Developers", 81 | "Intended Audience :: Science/Research", 82 | "License :: Other/Proprietary License", 83 | "Programming Language :: Python :: 3.9", 84 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 85 | "Topic :: Software Development :: Libraries :: Python Modules", 86 | ], 87 | ) 88 | -------------------------------------------------------------------------------- /torchhub/facebookresearch_dinov2_main/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import itertools 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 15 | 16 | 17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 18 | compact_arch_name = arch_name.replace("_", "")[:4] 19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 21 | 22 | 23 | class CenterPadding(nn.Module): 24 | def __init__(self, multiple): 25 | super().__init__() 26 | self.multiple = multiple 27 | 28 | def _get_pad(self, size): 29 | new_size = math.ceil(size / self.multiple) * self.multiple 30 | pad_size = new_size - size 31 | pad_size_left = pad_size // 2 32 | pad_size_right = pad_size - pad_size_left 33 | return pad_size_left, pad_size_right 34 | 35 | @torch.inference_mode() 36 | def forward(self, x): 37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) 38 | output = F.pad(x, pads) 39 | return output 40 | --------------------------------------------------------------------------------