├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── example_images
│   │   ├── arkit_depth.png
│   │   └── image.jpg
│   └── teaser.gif
├── promptda
│   ├── main.py
│   ├── model
│   │   ├── blocks.py
│   │   ├── config.py
│   │   └── dpt.py
│   ├── promptda.py
│   ├── scripts
│   │   ├── generate_video.py
│   │   ├── infer_stray_scan.py
│   │   └── sanity_check.py
│   └── utils
│       ├── depth_utils.py
│       ├── io_wrapper.py
│       ├── logger.py
│       └── parallel_utils.py
├── requirements.txt
├── setup.py
└── torchhub
    ├── README.md
    └── facebookresearch_dinov2_main
        ├── CODE_OF_CONDUCT.md
        ├── CONTRIBUTING.md
        ├── LICENSE
        ├── MODEL_CARD.md
        ├── README.md
        ├── conda.yaml
        ├── dinov2
        │   ├── __init__.py
        │   ├── configs
        │   │   ├── __init__.py
        │   │   ├── eval
        │   │   │   ├── vitb14_pretrain.yaml
        │   │   │   ├── vitg14_pretrain.yaml
        │   │   │   ├── vitl14_pretrain.yaml
        │   │   │   └── vits14_pretrain.yaml
        │   │   ├── ssl_default_config.yaml
        │   │   └── train
        │   │       ├── vitg14.yaml
        │   │       ├── vitl14.yaml
        │   │       └── vitl16_short.yaml
        │   ├── distributed
        │   │   └── __init__.py
        │   ├── eval
        │   │   ├── __init__.py
        │   │   ├── knn.py
        │   │   ├── linear.py
        │   │   ├── log_regression.py
        │   │   ├── metrics.py
        │   │   ├── setup.py
        │   │   └── utils.py
        │   ├── fsdp
        │   │   └── __init__.py
        │   ├── layers
        │   │   ├── __init__.py
        │   │   ├── attention.py
        │   │   ├── block.py
        │   │   ├── dino_head.py
        │   │   ├── drop_path.py
        │   │   ├── layer_scale.py
        │   │   ├── mlp.py
        │   │   ├── patch_embed.py
        │   │   └── swiglu_ffn.py
        │   ├── logging
        │   │   ├── __init__.py
        │   │   └── helpers.py
        │   ├── loss
        │   │   ├── __init__.py
        │   │   ├── dino_clstoken_loss.py
        │   │   ├── ibot_patch_loss.py
        │   │   └── koleo_loss.py
        │   ├── models
        │   │   ├── __init__.py
        │   │   └── vision_transformer.py
        │   ├── run
        │   │   ├── __init__.py
        │   │   ├── eval
        │   │   │   ├── knn.py
        │   │   │   ├── linear.py
        │   │   │   └── log_regression.py
        │   │   ├── submit.py
        │   │   └── train
        │   │       └── train.py
        │   ├── train
        │   │   ├── __init__.py
        │   │   ├── ssl_meta_arch.py
        │   │   └── train.py
        │   └── utils
        │       ├── __init__.py
        │       ├── cluster.py
        │       ├── config.py
        │       ├── dtype.py
        │       ├── param_groups.py
        │       └── utils.py
        ├── hubconf.py
        ├── pyproject.toml
        ├── scripts
        │   └── lint.sh
        ├── setup.cfg
        ├── setup.py
        ├── utils.py
        └── vision_transformer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # delete files larger than 10MiB
2 | **basicModel_neutral_lbs_10_207_0_v1.0.0.pkl
3 |
4 | .vscode
5 | .hydra
6 | *.txt
7 | # All file or folders start with tmp will be ignored
8 | tmp*
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | results/
22 | checkpoints/
23 | data/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 | cover/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | .pybuilder/
87 | target/
88 |
89 | # Jupyter Notebook
90 | .ipynb_checkpoints
91 |
92 | # IPython
93 | profile_default/
94 | ipython_config.py
95 |
96 | # pyenv
97 | # For a library or package, you might want to ignore these files since the code is
98 | # intended to run in multiple environments; otherwise, check them in:
99 | # .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock
107 |
108 | # poetry
109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110 | # This is especially recommended for binary packages to ensure reproducibility, and is more
111 | # commonly ignored for libraries.
112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113 | #poetry.lock
114 |
115 | # pdm
116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117 | #pdm.lock
118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119 | # in version control.
120 | # https://pdm.fming.dev/#use-with-ide
121 | .pdm.toml
122 |
123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
124 | __pypackages__/
125 |
126 | # Celery stuff
127 | celerybeat-schedule
128 | celerybeat.pid
129 |
130 | # SageMath parsed files
131 | *.sage.py
132 |
133 | # Environments
134 | .env
135 | .venv
136 | env/
137 | venv/
138 | ENV/
139 | env.bak/
140 | venv.bak/
141 |
142 | # Spyder project settings
143 | .spyderproject
144 | .spyproject
145 |
146 | # Rope project settings
147 | .ropeproject
148 |
149 | # mkdocs documentation
150 | /site
151 |
152 | # mypy
153 | .mypy_cache/
154 | .dmypy.json
155 | dmypy.json
156 |
157 | # Pyre type checker
158 | .pyre/
159 |
160 | # pytype static type analyzer
161 | .pytype/
162 |
163 | # Cython debug symbols
164 | cython_debug/
165 |
166 | # macOS
167 | .DS_Store
168 |
169 | # PyCharm
170 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
171 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
172 | # and can be added to the global gitignore or merged into this file. For a more nuclear
173 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
174 | #.idea/
175 |
176 | # torchsparse
177 | torchsparse
178 |
179 | # tensorboard
180 | tensorboard
181 |
182 | # glove
183 | glove
184 | *.jpg
185 | *.png
186 | *.mp4
187 |
188 | *.ipynb
189 | *.ply
190 | 3rdparty
191 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation
2 | ### [Project Page](https://promptda.github.io/) | [Paper](https://promptda.github.io/assets/main_paper_with_supp.pdf) | [Hugging Face Demo](https://huggingface.co/spaces/depth-anything/PromptDA) | [Interactive Results](https://promptda.github.io/interactive.html) | [Data](https://promptda.github.io/)
3 |
4 | > Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation
5 | > [Haotong Lin](https://haotongl.github.io/),
6 | [Sida Peng](https://pengsida.net/),
7 | [Jingxiao Chen](https://scholar.google.com/citations?user=-zs1V28AAAAJ),
8 | [Songyou Peng](https://pengsongyou.github.io/),
9 | [Jiaming Sun](https://jiamingsun.me/),
10 | [Minghuan Liu](https://minghuanliu.com/),
11 | [Hujun Bao](http://www.cad.zju.edu.cn/home/bao/),
12 | [Jiashi Feng](https://scholar.google.com/citations?user=Q8iay0gAAAAJ),
13 | [Xiaowei Zhou](https://www.xzhou.me/),
14 | [Bingyi Kang](https://bingykang.github.io/)
15 | > CVPR 2025
16 |
17 | 
18 |
19 |
20 | ## 🛠️ Installation
21 |
22 | Setting up the environment
23 |
24 | ```bash
25 | git clone https://github.com/DepthAnything/PromptDA.git
26 | cd PromptDA
27 | pip install -r requirements.txt
28 | pip install -e .
29 | sudo apt install ffmpeg # for video generation
30 | ```
31 |
32 | Pre-trained Models
33 |
34 | | Model | Params | Checkpoint |
35 | |:-|-:|:-:|
36 | | Prompt-Depth-Anything-Large | 340M | [Download](https://huggingface.co/depth-anything/prompt-depth-anything-vitl/resolve/main/model.ckpt) |
37 | | Prompt-Depth-Anything-Small | 25.1M | [Download](https://huggingface.co/depth-anything/prompt-depth-anything-vits/resolve/main/model.ckpt) |
38 | | Prompt-Depth-Anything-Small-Transparent | 25.1M | [Download](https://huggingface.co/depth-anything/prompt-depth-anything-vits-transparent/resolve/main/model.ckpt) |
39 |
40 | Only Prompt-Depth-Anything-Large is used for the benchmarks in our paper. Prompt-Depth-Anything-Small-Transparent is further fine-tuned for 10K steps on the [HAMMER dataset](https://github.com/Junggy/HAMMER-dataset) with our iPhone LiDAR simulation method to improve performance on transparent objects.
41 |
42 |
43 |
44 |
45 | ## 🚀 Usage
46 | Example usage
47 |
48 | ```python
49 | from promptda.promptda import PromptDA
50 | from promptda.utils.io_wrapper import load_image, load_depth, save_depth
51 |
52 | DEVICE = 'cuda'
53 | image_path = "assets/example_images/image.jpg"
54 | prompt_depth_path = "assets/example_images/arkit_depth.png"
55 | image = load_image(image_path).to(DEVICE)
56 | prompt_depth = load_depth(prompt_depth_path).to(DEVICE) # 192x256, ARKit LiDAR depth in meters
57 |
58 | model = PromptDA.from_pretrained("depth-anything/prompt-depth-anything-vitl").to(DEVICE).eval()
59 | depth = model.predict(image, prompt_depth) # HxW, depth in meters
60 |
61 | save_depth(depth, prompt_depth=prompt_depth, image=image)
62 | ```
63 |
64 |
65 |
66 | ## 📸 Running on your own capture
67 |
68 | You can use the [Stray Scanner App](https://apps.apple.com/us/app/stray-scanner/id1557051662) to capture your own data, which requires an iPhone 12 Pro or later Pro model, or a 2020 iPad Pro or later Pro model. We have set up a [Hugging Face Space](https://huggingface.co/spaces/depth-anything/PromptDA) for you to quickly test our model. If you want to obtain video results, please follow the steps below.
69 |
70 | Testing steps
71 |
72 | 1. Capture a scene with the Stray Scanner App. (Preferably hold the device with the charging port facing downward or to the right.)
73 | 2. Use the iPhone Files App to compress it into a zip file and transfer it to your computer. Here is an [example screen recording](https://haotongl.github.io/promptda/assets/ScreenRecording_12-16-2024.mp4).
74 | 3. Run the following commands to perform inference with our model and generate the video results.
75 | ```bash
76 | export PATH_TO_ZIP_FILE=data/8b98276b0a.zip # Replace with your own zip file path
77 | export PATH_TO_SAVE_FOLDER=data/8b98276b0a_results # Replace with your own save folder path
78 | python3 -m promptda.scripts.infer_stray_scan --input_path ${PATH_TO_ZIP_FILE} --output_path ${PATH_TO_SAVE_FOLDER}
79 | python3 -m promptda.scripts.generate_video process_stray_scan --input_path ${PATH_TO_ZIP_FILE} --result_path ${PATH_TO_SAVE_FOLDER}
80 | ffmpeg -framerate 60 -i ${PATH_TO_SAVE_FOLDER}/%06d_smooth.jpg -c:v libx264 -pix_fmt yuv420p ${PATH_TO_SAVE_FOLDER}.mp4
81 | ```
82 |
83 |
84 |
85 | ## 👏 Acknowledgements
86 | We thank Prof. [Weinan Zhang](https://wnzhang.net/) for his generous support of the robot experiments, including the space, objects, and the Unitree H1 robot. We also thank [Zhengbang Zhu](https://scholar.google.com/citations?user=ozatRA0AAAAJ), Jiahang Cao, Xinyao Li, and Wentao Dong for their help in setting up the robot platform and collecting robot data.
87 |
88 | ## 📚 Citation
89 | If you find this code useful for your research, please use the following BibTeX entry
90 | ```
91 | @inproceedings{lin2024promptda,
92 | title={Prompting Depth Anything for 4K Resolution Accurate Metric Depth Estimation},
93 | author={Lin, Haotong and Peng, Sida and Chen, Jingxiao and Peng, Songyou and Sun, Jiaming and Liu, Minghuan and Bao, Hujun and Feng, Jiashi and Zhou, Xiaowei and Kang, Bingyi},
94 | journal={arXiv},
95 | year={2024}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
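Note on the README's usage snippet above: it loads the Large checkpoint. Reading `PromptDA.from_pretrained` in `promptda/promptda.py`, the constructor defaults to `encoder='vitl'` and `from_pretrained` only forwards the downloaded checkpoint path, so loading the Small checkpoints from the table presumably needs an explicit `model_kwargs` override. A minimal sketch, assuming the `vits` Hugging Face repos host `model.ckpt` exactly as the README links suggest:

```python
from promptda.promptda import PromptDA

# Hedged sketch: the Small checkpoints use the ViT-S backbone, so the encoder
# is overridden here via model_kwargs (assumption: the checkpoint layout matches).
model = PromptDA.from_pretrained(
    "depth-anything/prompt-depth-anything-vits",   # or "...-vits-transparent"
    model_kwargs={'encoder': 'vits'},
).to('cuda').eval()
```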
/assets/example_images/arkit_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DepthAnything/PromptDA/d2b3560c66dfe6ab401dce18a330c00655827b83/assets/example_images/arkit_depth.png
--------------------------------------------------------------------------------
/assets/example_images/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DepthAnything/PromptDA/d2b3560c66dfe6ab401dce18a330c00655827b83/assets/example_images/image.jpg
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DepthAnything/PromptDA/d2b3560c66dfe6ab401dce18a330c00655827b83/assets/teaser.gif
--------------------------------------------------------------------------------
/promptda/main.py:
--------------------------------------------------------------------------------
1 | from promptda.utils.logger import Log
2 |
3 |
4 | def main():
5 | pass
6 |
7 |
8 | if __name__ == "__main__":
9 | main()
10 |
--------------------------------------------------------------------------------
/promptda/model/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from promptda.utils.logger import Log
5 | import os
6 | import numpy as np
7 |
8 |
9 | def _make_fusion_block(features, use_bn, size=None):
10 | return FeatureFusionDepthBlock(
11 | features,
12 | nn.ReLU(False),
13 | deconv=False,
14 | bn=use_bn,
15 | expand=False,
16 | align_corners=True,
17 | size=size,
18 | )
19 |
20 |
21 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
22 | scratch = nn.Module()
23 |
24 | out_shape1 = out_shape
25 | out_shape2 = out_shape
26 | out_shape3 = out_shape
27 | if len(in_shape) >= 4:
28 | out_shape4 = out_shape
29 |
30 | if expand:
31 | out_shape1 = out_shape
32 | out_shape2 = out_shape*2
33 | out_shape3 = out_shape*4
34 | if len(in_shape) >= 4:
35 | out_shape4 = out_shape*8
36 |
37 | scratch.layer1_rn = nn.Conv2d(
38 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
39 | )
40 | scratch.layer2_rn = nn.Conv2d(
41 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
42 | )
43 | scratch.layer3_rn = nn.Conv2d(
44 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
45 | )
46 | if len(in_shape) >= 4:
47 | scratch.layer4_rn = nn.Conv2d(
48 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
49 | )
50 |
51 | return scratch
52 |
53 |
54 | class ResidualConvUnit(nn.Module):
55 | """Residual convolution module.
56 | """
57 |
58 | def __init__(self, features, activation, bn):
59 | """Init.
60 |
61 | Args:
62 | features (int): number of features
63 | """
64 | super().__init__()
65 |
66 | self.bn = bn
67 |
68 | self.groups = 1
69 |
70 | self.conv1 = nn.Conv2d(
71 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
72 | )
73 |
74 | self.conv2 = nn.Conv2d(
75 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
76 | )
77 |
78 | if self.bn == True:
79 | self.bn1 = nn.BatchNorm2d(features)
80 | self.bn2 = nn.BatchNorm2d(features)
81 |
82 | self.activation = activation
83 |
84 | self.skip_add = nn.quantized.FloatFunctional()
85 |
86 | def forward(self, x):
87 | """Forward pass.
88 |
89 | Args:
90 | x (tensor): input
91 |
92 | Returns:
93 | tensor: output
94 | """
95 |
96 | out = self.activation(x)
97 | out = self.conv1(out)
98 | if self.bn == True:
99 | out = self.bn1(out)
100 |
101 | out = self.activation(out)
102 | out = self.conv2(out)
103 | if self.bn == True:
104 | out = self.bn2(out)
105 |
106 | if self.groups > 1:
107 | out = self.conv_merge(out)
108 |
109 | return self.skip_add.add(out, x)
110 |
111 |
112 | class FeatureFusionBlock(nn.Module):
113 | """Feature fusion block.
114 | """
115 |
116 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
117 | """Init.
118 |
119 | Args:
120 | features (int): number of features
121 | """
122 | super(FeatureFusionBlock, self).__init__()
123 |
124 | self.deconv = deconv
125 | self.align_corners = align_corners
126 |
127 | self.groups = 1
128 |
129 | self.expand = expand
130 | out_features = features
131 | if self.expand == True:
132 | out_features = features//2
133 |
134 | self.out_conv = nn.Conv2d(
135 | features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
136 |
137 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
138 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
139 |
140 | self.skip_add = nn.quantized.FloatFunctional()
141 |
142 | self.size = size
143 |
144 | def forward(self, *xs, size=None):
145 | """Forward pass.
146 |
147 | Returns:
148 | tensor: output
149 | """
150 | output = xs[0]
151 |
152 | if len(xs) == 2:
153 | res = self.resConfUnit1(xs[1])
154 | output = self.skip_add.add(output, res)
155 |
156 | output = self.resConfUnit2(output)
157 |
158 | if (size is None) and (self.size is None):
159 | modifier = {"scale_factor": 2}
160 | elif size is None:
161 | modifier = {"size": self.size}
162 | else:
163 | modifier = {"size": size}
164 |
165 | output = nn.functional.interpolate(
166 | output, **modifier, mode="bilinear", align_corners=self.align_corners
167 | )
168 |
169 | output = self.out_conv(output)
170 |
171 | return output
172 |
173 |
174 | class FeatureFusionControlBlock(FeatureFusionBlock):
175 | """Feature fusion block.
176 | """
177 |
178 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
179 | """Init.
180 |
181 | Args:
182 | features (int): number of features
183 | """
184 |         super().__init__(features, activation, deconv,
185 | bn, expand, align_corners, size)
186 | self.copy_block = FeatureFusionBlock(
187 | features, activation, deconv, bn, expand, align_corners, size)
188 |
189 | def forward(self, *xs, size=None):
190 | """Forward pass.
191 |
192 | Returns:
193 | tensor: output
194 | """
195 | output = xs[0]
196 |
197 | if len(xs) == 2:
198 | res = self.resConfUnit1(xs[1])
199 | output = self.skip_add.add(output, res)
200 |
201 | output = self.resConfUnit2(output)
202 |
203 | if (size is None) and (self.size is None):
204 | modifier = {"scale_factor": 2}
205 | elif size is None:
206 | modifier = {"size": self.size}
207 | else:
208 | modifier = {"size": size}
209 |
210 | output = nn.functional.interpolate(
211 | output, **modifier, mode="bilinear", align_corners=self.align_corners
212 | )
213 |
214 | output = self.out_conv(output)
215 |
216 | return output
217 |
218 |
219 | def zero_module(module):
220 | """
221 | Zero out the parameters of a module and return it.
222 | """
223 | for p in module.parameters():
224 | p.detach().zero_()
225 | return module
226 |
227 |
228 | class FeatureFusionDepthBlock(nn.Module):
229 | """Feature fusion block.
230 | """
231 |
232 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
233 | """Init.
234 |
235 | Args:
236 | features (int): number of features
237 | """
238 | super(FeatureFusionDepthBlock, self).__init__()
239 |
240 | self.deconv = deconv
241 | self.align_corners = align_corners
242 |
243 | self.groups = 1
244 |
245 | self.expand = expand
246 | out_features = features
247 | if self.expand == True:
248 | out_features = features//2
249 |
250 | self.out_conv = nn.Conv2d(
251 | features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
252 |
253 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
254 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
255 | self.resConfUnit_depth = nn.Sequential(
256 | nn.Conv2d(1, features, kernel_size=3, stride=1,
257 | padding=1, bias=True, groups=1),
258 | activation,
259 | nn.Conv2d(features, features, kernel_size=3,
260 | stride=1, padding=1, bias=True, groups=1),
261 | activation,
262 | zero_module(
263 | nn.Conv2d(features, features, kernel_size=3,
264 | stride=1, padding=1, bias=True, groups=1)
265 | )
266 | )
267 | self.skip_add = nn.quantized.FloatFunctional()
268 | self.size = size
269 |
270 | def forward(self, *xs, prompt_depth=None, size=None):
271 | """Forward pass.
272 |
273 | Returns:
274 | tensor: output
275 | """
276 | output = xs[0]
277 |
278 | if len(xs) == 2:
279 | res = self.resConfUnit1(xs[1])
280 | output = self.skip_add.add(output, res)
281 |
282 | output = self.resConfUnit2(output)
283 |
284 | if prompt_depth is not None:
285 | prompt_depth = F.interpolate(
286 | prompt_depth, output.shape[2:], mode='bilinear', align_corners=False)
287 | res = self.resConfUnit_depth(prompt_depth)
288 | output = self.skip_add.add(output, res)
289 |
290 | if (size is None) and (self.size is None):
291 | modifier = {"scale_factor": 2}
292 | elif size is None:
293 | modifier = {"size": self.size}
294 | else:
295 | modifier = {"size": size}
296 |
297 | output = nn.functional.interpolate(
298 | output, **modifier, mode="bilinear", align_corners=self.align_corners
299 | )
300 |
301 | output = self.out_conv(output)
302 |
303 | return output
304 |
--------------------------------------------------------------------------------
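A note on `FeatureFusionDepthBlock` above: its prompt-depth branch (`resConfUnit_depth`) ends in a `zero_module`-wrapped convolution, so at initialization the depth prompt contributes nothing and the block behaves like a plain `FeatureFusionBlock`; the prompt's influence is learned during training (a zero-convolution trick reminiscent of ControlNet). A minimal sketch that checks this, with toy shapes chosen only for illustration:

```python
import torch
import torch.nn as nn
from promptda.model.blocks import FeatureFusionDepthBlock

block = FeatureFusionDepthBlock(32, nn.ReLU(False))
feat = torch.randn(1, 32, 24, 24)      # a decoder feature map (toy size)
prompt = torch.rand(1, 1, 192, 256)    # a metric depth prompt (toy values)
with torch.no_grad():
    out_without = block(feat, size=(48, 48))
    out_with = block(feat, prompt_depth=prompt, size=(48, 48))
print(torch.allclose(out_without, out_with))  # True: the zeroed conv cancels the prompt at init
```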
/promptda/model/config.py:
--------------------------------------------------------------------------------
1 | model_configs = {
2 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384], 'layer_idxs': [2, 5, 8, 11]},
3 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768], 'layer_idxs': [2, 5, 8, 11]},
4 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024], 'layer_idxs': [4, 11, 17, 23]},
5 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536], 'layer_idxs': [9, 19, 29, 39]}
6 | }
7 |
--------------------------------------------------------------------------------
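For reference, `promptda/promptda.py` consumes these entries as follows: `layer_idxs` selects which DINOv2 transformer blocks are tapped via `get_intermediate_layers`, `out_channels` gives the per-tap widths that `DPTHead.projects` maps the tokens to, and `features` is the working channel width of the DPT decoder. A small illustrative look-up:

```python
from promptda.model.config import model_configs

cfg = model_configs['vitl']
print(cfg['layer_idxs'])    # [4, 11, 17, 23] -> four DINOv2 blocks feed the decoder
print(cfg['out_channels'])  # [256, 512, 1024, 1024] -> widths after DPTHead.projects
print(cfg['features'])      # 256 -> channel width inside the fusion/refinement blocks
```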
/promptda/model/dpt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, Depth Anything V2
2 | # https://github.com/DepthAnything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from promptda.model.blocks import _make_scratch, _make_fusion_block
7 |
8 |
9 | class DPTHead(nn.Module):
10 | def __init__(self,
11 | nclass,
12 | in_channels,
13 | features=256,
14 | out_channels=[256, 512, 1024, 1024],
15 | use_bn=False,
16 | use_clstoken=False,
17 | output_act='sigmoid'):
18 | super(DPTHead, self).__init__()
19 |
20 | self.nclass = nclass
21 | self.use_clstoken = use_clstoken
22 |
23 | self.projects = nn.ModuleList([
24 | nn.Conv2d(
25 | in_channels=in_channels,
26 | out_channels=out_channel,
27 | kernel_size=1,
28 | stride=1,
29 | padding=0,
30 | ) for out_channel in out_channels
31 | ])
32 |
33 | self.resize_layers = nn.ModuleList([
34 | nn.ConvTranspose2d(
35 | in_channels=out_channels[0],
36 | out_channels=out_channels[0],
37 | kernel_size=4,
38 | stride=4,
39 | padding=0),
40 | nn.ConvTranspose2d(
41 | in_channels=out_channels[1],
42 | out_channels=out_channels[1],
43 | kernel_size=2,
44 | stride=2,
45 | padding=0),
46 | nn.Identity(),
47 | nn.Conv2d(
48 | in_channels=out_channels[3],
49 | out_channels=out_channels[3],
50 | kernel_size=3,
51 | stride=2,
52 | padding=1)
53 | ])
54 |
55 | if use_clstoken:
56 | self.readout_projects = nn.ModuleList()
57 | for _ in range(len(self.projects)):
58 | self.readout_projects.append(
59 | nn.Sequential(
60 | nn.Linear(2 * in_channels, in_channels),
61 | nn.GELU()))
62 |
63 | self.scratch = _make_scratch(
64 | out_channels,
65 | features,
66 | groups=1,
67 | expand=False,
68 | )
69 |
70 | self.scratch.stem_transpose = None
71 |
72 | self.scratch.refinenet1 = _make_fusion_block(
73 | features, use_bn)
74 | self.scratch.refinenet2 = _make_fusion_block(
75 | features, use_bn)
76 | self.scratch.refinenet3 = _make_fusion_block(
77 | features, use_bn)
78 | self.scratch.refinenet4 = _make_fusion_block(
79 | features, use_bn)
80 |
81 | head_features_1 = features
82 | head_features_2 = 32
83 |
84 | act_func = nn.Sigmoid() if output_act == 'sigmoid' else nn.Identity()
85 |
86 | if nclass > 1:
87 | self.scratch.output_conv = nn.Sequential(
88 | nn.Conv2d(head_features_1, head_features_1,
89 | kernel_size=3, stride=1, padding=1),
90 | nn.ReLU(True),
91 | nn.Conv2d(head_features_1, nclass,
92 | kernel_size=1, stride=1, padding=0),
93 | )
94 | else:
95 | self.scratch.output_conv1 = nn.Conv2d(
96 | head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
97 |
98 | self.scratch.output_conv2 = nn.Sequential(
99 | nn.Conv2d(head_features_1 // 2, head_features_2,
100 | kernel_size=3, stride=1, padding=1),
101 | nn.ReLU(True),
102 | nn.Conv2d(head_features_2, 1, kernel_size=1,
103 | stride=1, padding=0),
104 | act_func,
105 | )
106 |
107 | def forward(self, out_features, patch_h, patch_w, prompt_depth=None):
108 | out = []
109 | for i, x in enumerate(out_features):
110 | if self.use_clstoken:
111 | x, cls_token = x[0], x[1]
112 | readout = cls_token.unsqueeze(1).expand_as(x)
113 | x = self.readout_projects[i](torch.cat((x, readout), -1))
114 | else:
115 | x = x[0]
116 |
117 | x = x.permute(0, 2, 1).reshape(
118 | (x.shape[0], x.shape[-1], patch_h, patch_w))
119 |
120 | x = self.projects[i](x)
121 | x = self.resize_layers[i](x)
122 |
123 | out.append(x)
124 |
125 | layer_1, layer_2, layer_3, layer_4 = out
126 |
127 | layer_1_rn = self.scratch.layer1_rn(layer_1)
128 | layer_2_rn = self.scratch.layer2_rn(layer_2)
129 | layer_3_rn = self.scratch.layer3_rn(layer_3)
130 | layer_4_rn = self.scratch.layer4_rn(layer_4)
131 |
132 | path_4 = self.scratch.refinenet4(
133 | layer_4_rn, size=layer_3_rn.shape[2:], prompt_depth=prompt_depth)
134 | path_3 = self.scratch.refinenet3(
135 | path_4, layer_3_rn, size=layer_2_rn.shape[2:], prompt_depth=prompt_depth)
136 | path_2 = self.scratch.refinenet2(
137 | path_3, layer_2_rn, size=layer_1_rn.shape[2:], prompt_depth=prompt_depth)
138 | path_1 = self.scratch.refinenet1(
139 | path_2, layer_1_rn, prompt_depth=prompt_depth)
140 | out = self.scratch.output_conv1(path_1)
141 | out_feat = F.interpolate(
142 | out, (int(patch_h * 14), int(patch_w * 14)),
143 | mode="bilinear", align_corners=True)
144 | out = self.scratch.output_conv2(out_feat)
145 | return out
146 |
--------------------------------------------------------------------------------
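A shape walkthrough of `DPTHead` above may help: with `use_clstoken=False`, each entry of `out_features` is a tuple whose first element holds the patch tokens of shape `(B, patch_h*patch_w, in_channels)`, and the head returns a `(B, 1, 14*patch_h, 14*patch_w)` map (sigmoid-activated by default). A minimal sketch with random tensors, assuming the ViT-L entry from `config.py`:

```python
import torch
from promptda.model.config import model_configs
from promptda.model.dpt import DPTHead

cfg = model_configs['vitl']
head = DPTHead(nclass=1, in_channels=1024, features=cfg['features'],
               out_channels=cfg['out_channels'])
patch_h, patch_w = 24, 32                        # toy token grid (a 336x448 image at patch size 14)
feats = [(torch.randn(1, patch_h * patch_w, 1024),) for _ in range(4)]
prompt = torch.rand(1, 1, 192, 256)              # prompt depth; resized internally per fusion block
with torch.no_grad():
    depth = head(feats, patch_h, patch_w, prompt_depth=prompt)
print(depth.shape)                               # torch.Size([1, 1, 336, 448])
```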
/promptda/promptda.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from promptda.model.dpt import DPTHead
4 | from promptda.model.config import model_configs
5 | from promptda.utils.logger import Log
6 | import os
7 | from pathlib import Path
8 | from huggingface_hub import hf_hub_download
9 |
10 |
11 | class PromptDA(nn.Module):
12 | patch_size = 14 # patch size of the pretrained dinov2 model
13 | use_bn = False
14 | use_clstoken = False
15 | output_act = 'sigmoid'
16 |
17 | def __init__(self,
18 | encoder='vitl',
19 | ckpt_path='checkpoints/promptda_vitl.ckpt'):
20 | super().__init__()
21 | model_config = model_configs[encoder]
22 |
23 | self.encoder = encoder
24 | self.model_config = model_config
25 | module_path = Path(__file__) # From anywhere else: module_path = Path(inspect.getfile(PromptDA))
26 | package_base_dir = str(Path(*module_path.parts[:-2])) # extract path to PromptDA module, then resolve to repo base dir for dinov2 load
27 | self.pretrained = torch.hub.load(
28 | f'{package_base_dir}/torchhub/facebookresearch_dinov2_main',
29 | 'dinov2_{:}14'.format(encoder),
30 | source='local',
31 | pretrained=False)
32 | dim = self.pretrained.blocks[0].attn.qkv.in_features
33 | self.depth_head = DPTHead(nclass=1,
34 | in_channels=dim,
35 | features=model_config['features'],
36 | out_channels=model_config['out_channels'],
37 | use_bn=self.use_bn,
38 | use_clstoken=self.use_clstoken,
39 | output_act=self.output_act)
40 |
41 | # mean and std of the pretrained dinov2 model
42 | self.register_buffer('_mean', torch.tensor(
43 | [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
44 | self.register_buffer('_std', torch.tensor(
45 | [0.229, 0.224, 0.225]).view(1, 3, 1, 1))
46 |
47 | self.load_checkpoint(ckpt_path)
48 |
49 | @classmethod
50 | def from_pretrained(cls, pretrained_model_name_or_path = None, model_kwargs = None, **hf_kwargs):
51 | """
52 | Load a model from a checkpoint file.
53 | ### Parameters:
54 | - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
55 | - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
56 | - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
57 | ### Returns:
58 |         - A new instance of `PromptDA` with the parameters loaded from the checkpoint.
59 | """
60 | ckpt_path = None
61 | if Path(pretrained_model_name_or_path).exists():
62 | ckpt_path = pretrained_model_name_or_path
63 | else:
64 | cached_checkpoint_path = hf_hub_download(
65 | repo_id=pretrained_model_name_or_path,
66 | repo_type="model",
67 | filename="model.ckpt",
68 | **hf_kwargs
69 | )
70 | ckpt_path = cached_checkpoint_path
71 | # model_config = checkpoint['model_config']
72 | # if model_kwargs is not None:
73 | # model_config.update(model_kwargs)
74 | if model_kwargs is None:
75 | model_kwargs = {}
76 | model_kwargs.update({'ckpt_path': ckpt_path})
77 | model = cls(**model_kwargs)
78 | return model
79 |
80 | def load_checkpoint(self, ckpt_path):
81 | if os.path.exists(ckpt_path):
82 | Log.info(f'Loading checkpoint from {ckpt_path}')
83 | checkpoint = torch.load(ckpt_path, map_location='cpu')
84 | self.load_state_dict(
85 | {k[9:]: v for k, v in checkpoint['state_dict'].items()})
86 | else:
87 | Log.warn(f'Checkpoint {ckpt_path} not found')
88 |
89 | def forward(self, x, prompt_depth=None):
90 | assert prompt_depth is not None, 'prompt_depth is required'
91 | prompt_depth, min_val, max_val = self.normalize(prompt_depth)
92 | h, w = x.shape[-2:]
93 | features = self.pretrained.get_intermediate_layers(
94 | (x - self._mean) / self._std, self.model_config['layer_idxs'],
95 | return_class_token=True)
96 | patch_h, patch_w = h // self.patch_size, w // self.patch_size
97 | depth = self.depth_head(features, patch_h, patch_w, prompt_depth)
98 | depth = self.denormalize(depth, min_val, max_val)
99 | return depth
100 |
101 | @torch.no_grad()
102 | def predict(self,
103 | image: torch.Tensor,
104 | prompt_depth: torch.Tensor):
105 | return self.forward(image, prompt_depth)
106 |
107 | def normalize(self,
108 | prompt_depth: torch.Tensor):
109 | B, C, H, W = prompt_depth.shape
110 | min_val = torch.quantile(
111 | prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None]
112 | max_val = torch.quantile(
113 | prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None]
114 | prompt_depth = (prompt_depth - min_val) / (max_val - min_val)
115 | return prompt_depth, min_val, max_val
116 |
117 | def denormalize(self,
118 | depth: torch.Tensor,
119 | min_val: torch.Tensor,
120 | max_val: torch.Tensor):
121 | return depth * (max_val - min_val) + min_val
122 |
--------------------------------------------------------------------------------
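Two practical details of `PromptDA.forward` above: the RGB input is expected to have height and width that are multiples of `patch_size = 14` (which is what `load_image` in `io_wrapper.py` enforces), and the prompt depth is min-max normalized per sample before entering the DPT head, with the prediction mapped back to meters afterwards. A minimal sketch of that normalization round trip with synthetic values:

```python
import torch

# Synthetic prompt depth in meters; shape (B, 1, H, W) as the model expects.
prompt = torch.rand(1, 1, 192, 256) * 4.0 + 0.2
B = prompt.shape[0]
flat = prompt.reshape(B, -1)
min_val = torch.quantile(flat, 0., dim=1, keepdim=True)[:, :, None, None]
max_val = torch.quantile(flat, 1., dim=1, keepdim=True)[:, :, None, None]
normed = (prompt - min_val) / (max_val - min_val)    # what the depth head conditions on
restored = normed * (max_val - min_val) + min_val    # denormalize() applied to the prediction
assert torch.allclose(restored, prompt, atol=1e-6)
```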
/promptda/scripts/generate_video.py:
--------------------------------------------------------------------------------
1 | import tyro
2 | from tqdm.auto import tqdm
3 | import numpy as np
4 | import glob
5 | import os
6 | import cv2
7 | import imageio
8 | from promptda.utils.depth_utils import smooth_min_max, visualize_depth
9 | from promptda.utils.io_wrapper import load_depth, load_image
10 | from promptda.utils.parallel_utils import async_call, parallel_execution
11 |
12 |
13 | def load_depths(depth_paths: list[str]) -> list[np.ndarray]:
14 | depths = parallel_execution(depth_paths,
15 | to_tensor=False,
16 | action=load_depth,
17 | num_processes=32,
18 | desc='Loading depths')
19 | return depths
20 |
21 |
22 | def load_imgs(rgb_paths: list[str], max_size: int) -> list[np.ndarray]:
23 | rgbs = parallel_execution(rgb_paths,
24 | to_tensor=False,
25 | max_size=max_size,
26 | action=load_image,
27 | num_processes=32,
28 | desc='Loading RGB images')
29 | return rgbs
30 |
31 |
32 | def load_result_depths(result_path: str) -> list[np.ndarray]:
33 | depth_paths = sorted(glob.glob(os.path.join(result_path, '*.png')))
34 | depths = load_depths(depth_paths)
35 | return depths
36 |
37 |
38 | def load_prompt_depths(input_path: str) -> list[np.ndarray]:
39 | prompt_depth_paths = sorted(glob.glob(os.path.join(input_path, 'depth/*.png')))
40 | prompt_depths = load_depths(prompt_depth_paths)
41 | return prompt_depths
42 |
43 |
44 | def load_rgbs(input_path: str, max_size: int) -> list[np.ndarray]:
45 | rgb_paths = sorted(glob.glob(os.path.join(input_path, 'rgb/*.jpg')))
46 | rgbs = load_imgs(rgb_paths, max_size)
47 | return rgbs
48 |
49 | @async_call
50 | def generate_frame(
51 | depth: np.ndarray,
52 | min_val: float,
53 | max_val: float,
54 | frame_idx: int,
55 | result_path: str,
56 | prompt_depth: np.ndarray = None,
57 | rgb: np.ndarray = None,
58 | ) -> None:
59 | output_img = visualize_depth(depth, min_val, max_val)
60 | if prompt_depth is not None:
61 | prompt_depth_img = visualize_depth(prompt_depth, min_val, max_val)
62 | if prompt_depth_img.shape[:2] != depth.shape[:2]:
63 | prompt_depth_img = cv2.resize(prompt_depth_img, (depth.shape[1], depth.shape[0]))
64 | output_img = np.concatenate([output_img, prompt_depth_img], axis=1)
65 | if rgb is not None:
66 | if rgb.shape[:2] != depth.shape[:2]:
67 | rgb = cv2.resize(rgb, (depth.shape[1], depth.shape[0]))
68 | if rgb.dtype == np.float32 or rgb.dtype == np.float64:
69 | rgb = (rgb * 255).astype(np.uint8)
70 | output_img = np.concatenate([rgb, output_img], axis=1)
71 | imageio.imwrite(os.path.join(result_path, f'{frame_idx:06d}_smooth.jpg'), output_img)
72 |
73 |
74 | def process_stray_scan(input_path: str = 'data/8b98276b0a',
75 | result_path: str = 'data/8b98276b0a_results',
76 | include_prompt: bool = True,
77 | include_rgb: bool = True,
78 | percentile: float = 2,
79 | smooth_interval: int = 60) -> None:
80 | result_depths = load_result_depths(result_path)
81 | min_vals = [np.percentile(depth, percentile) for depth in result_depths]
82 | max_vals = [np.percentile(depth, 100 - percentile) for depth in result_depths]
83 | min_vals_smooth, max_vals_smooth = smooth_min_max(min_vals, max_vals, smooth_interval)
84 |
85 | if include_prompt:
86 | prompt_depths = load_prompt_depths(input_path)
87 | if include_rgb:
88 | rgbs = load_rgbs(input_path, max(result_depths[0].shape))
89 |
90 |     lengths = [len(result_depths)] + ([len(prompt_depths)] if include_prompt else []) + ([len(rgbs)] if include_rgb else [])
91 |     min_len = min(lengths)  # avoids a NameError when prompt depths or RGB frames are not loaded
92 |     result_depths = result_depths[:min_len]
93 |     prompt_depths, rgbs = (prompt_depths[:min_len] if include_prompt else None), (rgbs[:min_len] if include_rgb else None)
94 |
95 | for frame_idx in tqdm(range(len(result_depths)), desc='Generating frames'):
96 | generate_frame(result_depths[frame_idx],
97 | min_vals_smooth[frame_idx],
98 | max_vals_smooth[frame_idx],
99 | frame_idx,
100 | result_path,
101 | prompt_depths[frame_idx] if include_prompt else None,
102 | rgbs[frame_idx] if include_rgb else None)
103 |
104 | def main() -> None:
105 | pass
106 |
107 |
108 | if __name__ == "__main__":
109 | tyro.extras.subcommand_cli_from_dict(
110 | {
111 | 'process_stray_scan': process_stray_scan,
112 | 'main': main,
113 | }
114 | )
115 |
116 |
--------------------------------------------------------------------------------
/promptda/scripts/infer_stray_scan.py:
--------------------------------------------------------------------------------
1 | import tyro
2 | import os
3 | import glob
4 | from tqdm.auto import tqdm
5 |
6 | from promptda.utils.io_wrapper import load_image, load_depth, save_depth
7 | from promptda.utils.parallel_utils import parallel_execution
8 | from promptda.promptda import PromptDA
9 |
10 | def load_data(input_path: str, max_size: int):
11 | root_dir = os.path.dirname(input_path)
12 | scene_name = input_path.split('/')[-1].split('.')[0]
13 | input_dir = os.path.join(root_dir, scene_name)
14 | if not os.path.exists(input_dir):
15 | cmd = f'unzip -o {input_path} -d {root_dir}'
16 | os.system(cmd)
17 |
18 | if not os.path.exists(os.path.join(input_dir, 'rgb')):
19 | os.makedirs(os.path.join(input_dir, 'rgb'), exist_ok=True)
20 | cmd = f'ffmpeg -i {input_dir}/rgb.mp4 -start_number 0 -q:v 2 {input_dir}/rgb/%06d.jpg'
21 | os.system(cmd)
22 |
23 | rgb_files = sorted(glob.glob(os.path.join(input_dir, 'rgb', '*.jpg')))
24 | prompt_depth_files = sorted(glob.glob(os.path.join(input_dir, 'depth', '*.png')))
25 |
26 | if len(rgb_files) != len(prompt_depth_files):
27 | min_len = min(len(rgb_files), len(prompt_depth_files))
28 | rgb_files = rgb_files[:min_len]
29 | prompt_depth_files = prompt_depth_files[:min_len]
30 |
31 | rgbs = parallel_execution(rgb_files,
32 | to_tensor=True, # to_tensor
33 | max_size=max_size,
34 | action=load_image,
35 | num_processes=32,
36 | print_progress=True,
37 | desc='Loading RGB images')
38 |
39 | prompt_depths = parallel_execution(prompt_depth_files,
40 | to_tensor=True, # to_tensor
41 | action=load_depth,
42 | num_processes=32,
43 | print_progress=True,
44 | desc='Loading Prompt Depth')
45 | return rgbs, prompt_depths
46 |
47 |
48 | def main(input_path: str = 'data/8b98276b0a.zip',
49 | output_path: str = 'data/8b98276b0a_results',
50 | max_size: int = 1008,
51 | ):
52 | os.makedirs(output_path, exist_ok=True)
53 | rgbs, prompt_depths = load_data(input_path, max_size)
54 | results = []
55 | DEVICE = 'cuda'
56 | model = PromptDA.from_pretrained("depth-anything/prompt-depth-anything-vitl").to(DEVICE).eval()
57 | for frame_idx, (rgb, prompt_depth) in tqdm(enumerate(zip(rgbs, prompt_depths)), desc='Inferring', total=len(rgbs)):
58 | rgb, prompt_depth = rgb.to(DEVICE), prompt_depth.to(DEVICE)
59 | depth = model.predict(rgb, prompt_depth)
60 | save_depth(depth.detach().cpu(),
61 | output_path=os.path.join(output_path, f'{frame_idx:06d}.png'),
62 | save_vis=True)
63 |
64 | if __name__ == "__main__":
65 | tyro.cli(main)
--------------------------------------------------------------------------------
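The script above is the one invoked in the README via `python3 -m promptda.scripts.infer_stray_scan ...`; since `main` is a plain function wrapped by `tyro.cli`, it can also be called from Python. A sketch using the default placeholder paths from the signature (a CUDA device and the downloaded checkpoint are required):

```python
from promptda.scripts.infer_stray_scan import main

# Placeholder paths: point input_path at your own Stray Scanner zip.
main(input_path='data/8b98276b0a.zip',
     output_path='data/8b98276b0a_results',
     max_size=1008)
```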
/promptda/scripts/sanity_check.py:
--------------------------------------------------------------------------------
1 | from promptda.promptda import PromptDA
2 | from promptda.utils.io_wrapper import load_image, load_depth, save_depth
3 |
4 | DEVICE = 'cuda'
5 | image_path = "assets/example_images/image.jpg"
6 | prompt_depth_path = "assets/example_images/arkit_depth.png"
7 | image = load_image(image_path).to(DEVICE)
8 | prompt_depth = load_depth(prompt_depth_path).to(DEVICE) # 192x256, ARKit LiDAR depth in meters
9 |
10 | model = PromptDA.from_pretrained("depth-anything/prompt-depth-anything-vitl").to(DEVICE).eval()
11 | depth = model.predict(image, prompt_depth) # HxW, depth in meters
12 |
13 | save_depth(depth, prompt_depth=prompt_depth, image=image)
--------------------------------------------------------------------------------
/promptda/utils/depth_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | import open3d as o3d
4 | from scipy.interpolate import CubicSpline
5 |
6 | def visualize_depth(depth: np.ndarray,
7 | depth_min=None,
8 | depth_max=None,
9 | percentile=2,
10 | ret_minmax=False,
11 | cmap='Spectral'):
12 | if depth_min is None: depth_min = np.percentile(depth, percentile)
13 | if depth_max is None: depth_max = np.percentile(depth, 100 - percentile)
14 | if depth_min == depth_max:
15 | depth_min = depth_min - 1e-6
16 | depth_max = depth_max + 1e-6
17 | cm = matplotlib.colormaps[cmap]
18 | depth = ((depth - depth_min) / (depth_max - depth_min)).clip(0, 1)
19 | img_colored_np = cm(depth[None], bytes=False)[:, :, :, 0:3] # value from 0 to 1
20 | img_colored_np = (img_colored_np[0] * 255.0).astype(np.uint8)
21 | if ret_minmax:
22 | return img_colored_np, depth_min, depth_max
23 | else:
24 | return img_colored_np
25 |
26 |
27 | def unproject_depth(depth,
28 | ixt,
29 | depth_min=0.01,
30 | depth_max=None,
31 | color=None,
32 | ext=None,
33 | conf=None,
34 | ret_pcd=False,
35 | clip_box=None):
36 | height, width = depth.shape
37 | x = np.arange(0, width)
38 | y = np.arange(0, height)
39 | xx, yy = np.meshgrid(x, y)
40 | xx = xx.reshape(-1)
41 | yy = yy.reshape(-1)
42 | zz = depth.reshape(-1)
43 | mask = np.ones_like(xx, dtype=np.bool_)
44 | if depth_min is not None:
45 | mask &= zz >= depth_min
46 | if depth_max is not None:
47 | mask &= zz <= depth_max
48 | if conf is not None:
49 | mask &= conf.reshape(-1) == 2
50 | xx = xx[mask]
51 | yy = yy[mask]
52 | zz = zz[mask]
53 | pcd = np.stack([xx, yy, np.ones_like(xx)], axis=1)
54 | pcd = pcd * zz[:, None]
55 | pcd = np.dot(pcd, np.linalg.inv(ixt).T)
56 | if ext is not None:
57 | pcd = np.concatenate([pcd, np.ones((pcd.shape[0], 1))], axis=1)
58 | pcd = np.dot(pcd, np.linalg.inv(ext).T)
59 | new_mask = np.ones_like(pcd[:, 0]).astype(np.bool_)
60 | if clip_box is not None:
61 | assert len(clip_box) == 6
62 | for i, val in enumerate(clip_box):
63 | if val is None:
64 | continue
65 | if i == 0: new_mask &= (pcd[:, 0] <= val)
66 | elif i == 1: new_mask &= (pcd[:, 1] <= val)
67 | elif i == 2: new_mask &= (pcd[:, 2] <= val)
68 | elif i == 3: new_mask &= (pcd[:, 0] >= val)
69 | elif i == 4: new_mask &= (pcd[:, 1] >= val)
70 | elif i == 5: new_mask &= (pcd[:, 2] >= val)
71 | if color is not None:
72 | if color.dtype == np.uint8:
73 | color = color.astype(np.float32) / 255.
74 | if ret_pcd:
75 | points = pcd
76 | pcd = o3d.geometry.PointCloud()
77 | pcd.points = o3d.utility.Vector3dVector(points[:, :3][new_mask])
78 | pcd.colors = o3d.utility.Vector3dVector(color.reshape(-1, 3)[mask][new_mask])
79 | else:
80 | return pcd[:, :3][new_mask], color.reshape(-1, 3)[mask][new_mask]
81 | else:
82 | if ret_pcd:
83 | points = pcd
84 | pcd = o3d.geometry.PointCloud()
85 | pcd.points = o3d.utility.Vector3dVector(pcd[:, :3][new_mask])
86 | else:
87 | return pcd[:, :3][new_mask]
88 | return pcd
89 |
90 | def smooth_min_max(min_vals,
91 | max_vals,
92 | interval: int = 60):
93 | '''
94 |     Interpolate and smooth per-frame min and max values with natural cubic splines fitted on key frames (one every `interval` frames)
95 | Args:
96 | min_vals: list[float]
97 | max_vals: list[float]
98 | Returns:
99 | min_vals_smooth: list[float]
100 | max_vals_smooth: list[float]
101 | '''
102 |
103 | key_frames = list(range(0, len(min_vals), interval))
104 | if key_frames[-1] != len(min_vals) - 1:
105 | key_frames.append(len(min_vals) - 1)
106 |
107 | key_frame_indices = np.array(key_frames)
108 | min_key_vals = np.array([min_vals[i] for i in key_frames])
109 | max_key_vals = np.array([max_vals[i] for i in key_frames])
110 |
111 | # Use CubicSpline for smooth interpolation
112 | min_spline = CubicSpline(
113 | key_frame_indices, min_key_vals, bc_type='natural')
114 | max_spline = CubicSpline(
115 | key_frame_indices, max_key_vals, bc_type='natural')
116 |
117 | x_full = np.arange(len(min_vals))
118 | min_vals_smooth = min_spline(x_full)
119 | max_vals_smooth = max_spline(x_full)
120 | # plt.plot(min_vals, label='min_vals')
121 | # plt.plot(min_vals_smooth, label='min_vals_smooth')
122 | # plt.legend()
123 | # plt.savefig('min_vals.png')
124 | return min_vals_smooth, max_vals_smooth
125 |
126 | if __name__ == '__main__':
127 | depth = np.random.rand(100, 100)
128 | visualize_depth(depth)
--------------------------------------------------------------------------------
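A small usage sketch for the two helpers above, with synthetic data: `visualize_depth` returns a `uint8` HxWx3 colormapped image, and `smooth_min_max` fits natural cubic splines through key frames taken every `interval` frames.

```python
import numpy as np
from promptda.utils.depth_utils import smooth_min_max, visualize_depth

depth = np.random.rand(192, 256).astype(np.float32) * 3.0   # synthetic depth in meters
vis, dmin, dmax = visualize_depth(depth, ret_minmax=True)
print(vis.shape, vis.dtype)                                  # (192, 256, 3) uint8

mins = list(np.random.rand(300) * 0.5)                       # synthetic per-frame display ranges
maxs = list(np.random.rand(300) * 0.5 + 3.0)
mins_s, maxs_s = smooth_min_max(mins, maxs, interval=60)     # smoothed, one value per frame
```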
/promptda/utils/io_wrapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import torch
4 | import os
5 | import matplotlib.pyplot as plt
6 | import cv2
7 |
8 | from promptda.utils.logger import Log
9 | from promptda.utils.depth_utils import visualize_depth
10 | from promptda.utils.parallel_utils import async_call
11 |
12 | # DEVICE = 'cuda' if torch.cuda.is_available(
13 | # ) else 'mps' if torch.backends.mps.is_available() else 'cpu'
14 |
15 |
16 | def to_tensor_func(arr):
17 | if arr.ndim == 2:
18 | arr = arr[:, :, np.newaxis]
19 | return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
20 |
21 |
22 | def to_numpy_func(tensor):
23 | arr = tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
24 | if arr.shape[2] == 1:
25 | arr = arr[:, :, 0]
26 | return arr
27 |
28 |
29 | def ensure_multiple_of(x, multiple_of=14):
30 | return int(x // multiple_of * multiple_of)
31 |
32 |
33 | def load_image(image_path, to_tensor=True, max_size=1008, multiple_of=14):
34 | '''
35 | Load image from path and convert to tensor
36 |     max_size is rounded down to a multiple of `multiple_of` (i.e. max_size % multiple_of == 0)
37 | '''
38 | image = np.asarray(imageio.imread(image_path)).astype(np.float32)
39 | image = image / 255.
40 |
41 | max_size = max_size // multiple_of * multiple_of
42 | if max(image.shape) > max_size:
43 | h, w = image.shape[:2]
44 | scale = max_size / max(h, w)
45 | tar_h = ensure_multiple_of(h * scale)
46 | tar_w = ensure_multiple_of(w * scale)
47 | image = cv2.resize(image, (tar_w, tar_h), interpolation=cv2.INTER_AREA)
48 | if to_tensor:
49 | return to_tensor_func(image)
50 | return image
51 |
52 |
53 | def load_depth(depth_path, to_tensor=True):
54 | '''
55 | Load depth from path and convert to tensor
56 | '''
57 | if depth_path.endswith('.png'):
58 | depth = np.asarray(imageio.imread(depth_path)).astype(np.float32)
59 | depth = depth / 1000.
60 | elif depth_path.endswith('.npz'):
61 | depth = np.load(depth_path)['depth']
62 | else:
63 | raise ValueError(f"Unsupported depth format: {depth_path}")
64 | if to_tensor:
65 | return to_tensor_func(depth)
66 | return depth
67 |
68 |
69 | @async_call
70 | def save_depth(depth,
71 | prompt_depth=None,
72 | image=None,
73 | output_path='results/example_depth.png',
74 | save_vis=True):
75 | '''
76 | Save depth to path
77 | '''
78 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
79 | depth = to_numpy_func(depth)
80 | uint16_depth = (depth * 1000.).astype(np.uint16)
81 | imageio.imwrite(output_path, uint16_depth)
82 | Log.info(f'Saved depth to {output_path}', tag='save_depth')
83 |
84 | if not save_vis:
85 | return
86 | output_path_ = output_path
87 | output_path = output_path_.replace('.png', '_depth.jpg')
88 | depth_vis, depth_min, depth_max = visualize_depth(depth, ret_minmax=True)
89 | imageio.imwrite(output_path, depth_vis)
90 |
91 |
92 | if prompt_depth is not None:
93 | prompt_depth = to_numpy_func(prompt_depth)
94 | output_path = output_path_.replace('.png', '_prompt_depth.jpg')
95 | prompt_depth_vis = visualize_depth(prompt_depth,
96 | depth_min=depth_min,
97 | depth_max=depth_max)
98 | imageio.imwrite(output_path, prompt_depth_vis)
99 |
100 | if image is not None:
101 | output_path = output_path_.replace('.png', '_image.jpg')
102 | image = to_numpy_func(image)
103 | imageio.imwrite(output_path, (image * 255).astype(np.uint8))
104 |
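A minimal usage sketch for the helpers above (not part of the source file; the input paths come from the repository's `example_images` folder and the output location is arbitrary):

```python
from promptda.utils.io_wrapper import load_image, load_depth, save_depth

# RGB image as a 1x3xHxW float tensor, resized so H and W are multiples of 14.
image = load_image('example_images/image.jpg')
# Prompt depth (16-bit PNG in millimetres) as a 1x1xHxW tensor in metres.
prompt_depth = load_depth('example_images/arkit_depth.png')

# Writes results/example_depth.png (uint16, millimetres) and, since save_vis
# defaults to True, JPEG visualizations next to it. The call returns
# immediately because save_depth is decorated with @async_call.
save_depth(prompt_depth, prompt_depth=prompt_depth, image=image,
           output_path='results/example_depth.png')
```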
--------------------------------------------------------------------------------
/promptda/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class Log:
5 | log_on = True # fast switch
6 | used_tags = dict() # To keep track of used tags
7 | _is_main_cached = None # Cache to store the main process check result
8 |
9 | @staticmethod
10 | def is_main_process():
11 | if Log._is_main_cached is not None:
12 | return Log._is_main_cached
13 | try:
14 | from pytorch_lightning.utilities import rank_zero_only
15 | if rank_zero_only.rank == 0:
16 | Log._is_main_cached = True
17 | else:
18 | Log._is_main_cached = False
19 |         except Exception:
20 | Log._is_main_cached = True
21 | return Log._is_main_cached
22 |
23 | @staticmethod
24 | def _should_log(tag):
25 | """
26 | Determine if the log information should be recorded.
27 | Conditions: log function is enabled, current process is the main process, and the tag has not been used.
28 | """
29 | if not Log.log_on:
30 | return False
31 | if not Log.is_main_process():
32 | return False
33 | if tag is None:
34 | return True
35 | if '__' in tag:
36 | num = int(tag.split('__')[-1])
37 |             tag = tag.split('__')[0]  # the same tag may be logged up to `num` times
38 | else:
39 | num = 3 # default 3
40 |
41 | if tag not in Log.used_tags:
42 | Log.used_tags[tag] = num
43 | Log.used_tags[tag] -= 1
44 | if Log.used_tags[tag] >= 0:
45 | return True
46 | else:
47 | return False
48 |
49 | @staticmethod
50 | def info(*args, tag=None):
51 | """
52 | Output INFO level log information.
53 | """
54 | if Log._should_log(tag):
55 | print("\033[1;32m[INFO]\033[0;0m", *args)
56 |
57 | @staticmethod
58 | def warn(*args, tag=None):
59 | """
60 | Output WARN level log information.
61 | """
62 | if Log._should_log(tag):
63 | print("\033[1;35m[WARN]\033[0;0m", *args)
64 |
65 | @staticmethod
66 | def error(*args, tag=None):
67 | print("\033[1;31m[ERROR]\033[0;0m", *args)
68 |
69 | @staticmethod
70 | def debug(*args, tag=None):
71 | """
72 | Output DEBUG level log information.
73 | """
74 | if Log._should_log(tag) and 'HT_DEBUG' in os.environ and os.environ['HT_DEBUG'] == '1':
75 | print("\033[1;33m[DEBUG]\033[0;0m", *args)
76 |
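A minimal usage sketch of the tag-based rate limiting above (not part of the source file): a plain tag is printed at most 3 times, and a `__N` suffix raises that budget to N.

```python
from promptda.utils.logger import Log

for step in range(100):
    Log.info('loaded batch', tag='load_batch')              # printed at most 3 times
    Log.warn('empty prompt depth', tag='empty_prompt__10')  # printed at most 10 times
    Log.debug('step', step)  # only printed when the env var HT_DEBUG=1 is set
```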
--------------------------------------------------------------------------------
/promptda/utils/parallel_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Dict
2 | from multiprocessing.pool import ThreadPool
3 | from tqdm import tqdm
4 | from threading import Thread
5 | import asyncio
6 | from functools import wraps, partial
7 |
8 |
9 | def async_call_func(func):
10 | @wraps(func)
11 | async def wrapper(*args, **kwargs):
12 |         loop = asyncio.get_running_loop()
13 |         # run_in_executor does not forward keyword arguments, so bind them with partial
14 |         return await loop.run_in_executor(None, partial(func, *args, **kwargs))
15 | return wrapper
16 |
17 |
18 | def async_call(fn):
19 | def wrapper(*args, **kwargs):
20 | Thread(target=fn, args=args, kwargs=kwargs).start()
21 | return wrapper
22 |
23 |
24 | def parallel_execution(*args, action: Callable, num_processes=32, print_progress=False, sequential=False, async_return=False, desc=None, **kwargs):
25 | # Copy from EasyVolCap
26 | # Author: Zhen Xu https://github.com/dendenxu
27 | # NOTE: we expect first arg / or kwargs to be distributed
28 | # NOTE: print_progress arg is reserved
29 | def get_length(args: List, kwargs: Dict):
30 | for a in args:
31 | if isinstance(a, list):
32 | return len(a)
33 | for v in kwargs.values():
34 | if isinstance(v, list):
35 | return len(v)
36 | raise NotImplementedError
37 |
38 | def get_action_args(length: int, args: List, kwargs: Dict, i: int):
39 | action_args = [(arg[i] if isinstance(arg, list) and len(
40 | arg) == length else arg) for arg in args]
41 | # TODO: Support all types of iterable
42 | action_kwargs = {key: (kwargs[key][i] if isinstance(kwargs[key], list) and len(
43 | kwargs[key]) == length else kwargs[key]) for key in kwargs}
44 | return action_args, action_kwargs
45 |
46 | if not sequential:
47 | # Create ThreadPool
48 | pool = ThreadPool(processes=num_processes)
49 |
50 | # Spawn threads
51 | results = []
52 | asyncs = []
53 | length = get_length(args, kwargs)
54 | for i in range(length):
55 | action_args, action_kwargs = get_action_args(
56 | length, args, kwargs, i)
57 | async_result = pool.apply_async(action, action_args, action_kwargs)
58 | asyncs.append(async_result)
59 |
60 | # Join threads and get return values
61 | if not async_return:
62 | for async_result in tqdm(asyncs, desc=desc, disable=not print_progress):
63 | # will sync the corresponding thread
64 | results.append(async_result.get())
65 | pool.close()
66 | pool.join()
67 | return results
68 | else:
69 | return pool
70 | else:
71 | results = []
72 | length = get_length(args, kwargs)
73 | for i in tqdm(range(length), desc=desc, disable=not print_progress):
74 | action_args, action_kwargs = get_action_args(
75 | length, args, kwargs, i)
76 | async_result = action(*action_args, **action_kwargs)
77 | results.append(async_result)
78 | return results
79 |
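A minimal usage sketch for `parallel_execution` above (not part of the source file; the worker function and file names are hypothetical): list-valued arguments are split element-wise across threads, while scalar arguments are broadcast to every call.

```python
from promptda.utils.parallel_utils import parallel_execution

def load_frame(path, scale):
    # hypothetical worker: load `path`, downscale by `scale`, return the result
    return path, scale

paths = [f'frame_{i:04d}.jpg' for i in range(8)]
# `paths` is a list, so thread i receives paths[i]; `scale` is shared by all calls.
results = parallel_execution(paths, scale=0.5, action=load_frame,
                             num_processes=4, print_progress=True,
                             desc='loading frames')
```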
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # python version >= 3.9 <= 3.11
2 | # torch
3 | torch==2.0.1
4 | torchvision==0.15.2
5 | torchaudio==2.0.2
6 | xformers==0.0.22
7 | #
8 | lightning==2.1.3
9 | imageio>=2.33.1
10 | Pillow>=10.1.0
11 | imageio-ffmpeg
12 | einops
13 | tqdm
14 | ipdb
15 | # Diffusion
16 | transformers
17 | # ray
18 | termcolor
19 | numpy==1.26.4
20 | opencv-python==4.9.0.80
21 | scipy
22 | matplotlib
23 | h5py
24 | tyro==0.9.2
25 | # open3d
26 |
27 | # app.py
28 | gradio==4.44.1
29 | gradio-imageslider==0.0.20
30 | spaces
31 | trimesh
32 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="promptda",
5 | version="1.0",
6 |     packages=find_packages(),
7 |     author="Haotong Lin",
8 |     description="Prompt Depth Anything",
9 | url="https://github.com/DepthAnything/PromptDA",
10 | )
11 |
--------------------------------------------------------------------------------
/torchhub/README.md:
--------------------------------------------------------------------------------
1 | # Local PyTorch Hub
2 |
3 | This directory holds a local copy of the DINOv2 code so the encoder can be loaded with `torch.hub` even without an Internet connection.
4 |
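A minimal sketch of how such a local hub directory can be used (illustrative only; the exact entry point and arguments used by PromptDA may differ):

```python
import torch

# Load the DINOv2 ViT-L/14 backbone from this directory instead of fetching
# facebookresearch/dinov2 from GitHub. source='local' makes torch.hub read the
# hubconf.py found in the given directory.
encoder = torch.hub.load(
    'torchhub/facebookresearch_dinov2_main',
    'dinov2_vitl14',
    source='local',
    pretrained=False,  # skip the weight download; load a checkpoint separately
)
```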
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to DINOv2
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Meta's open source projects.
18 |
19 | Complete your CLA here:
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## License
30 | By contributing to DINOv2, you agree that your contributions will be licensed
31 | under the LICENSE file in the root directory of this source tree.
32 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/MODEL_CARD.md:
--------------------------------------------------------------------------------
1 | # Model Card for DINOv2-S/B/L/g
2 |
3 | These are Vision Transformer models trained following the method described in the paper:
4 | "DINOv2: Learning Robust Visual Features without Supervision"
5 |
6 | We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g.
7 |
8 | ## Model Details
9 | The model takes an image as input and returns a class token and patch tokens.
10 |
11 | The embedding dimension is:
12 | - 384 for ViT-S.
13 | - 768 for ViT-B.
14 | - 1024 for ViT-L.
15 | - 1536 for ViT-g.
16 |
17 | The models follow a Transformer architecture, with a patch size of 14.
18 |
19 | For a 224x224 image, this results in 1 class token + 256 patch tokens.
20 |
21 | The models can accept larger images provided the image shapes are multiples of the patch size (14).
22 | If this condition is not met, the model will crop to the closest smaller multiple of the patch size.
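For example (an illustrative calculation, not part of the original model card), the token count for a given input size follows directly from the patch size:

```python
# Patch size 14: a 224x224 input gives (224/14)**2 = 256 patch tokens + 1 class token.
patch_size = 14
H, W = 518, 518                                            # any multiple of 14 is accepted as-is
num_patch_tokens = (H // patch_size) * (W // patch_size)   # 37 * 37 = 1369
num_tokens = 1 + num_patch_tokens                          # class token + patch tokens
```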
23 |
24 | ### Model Description
25 |
26 | - **Developed by:** Meta AI
27 | - **Model type:** Vision Transformer
28 | - **License:** CC-BY-NC
29 |
30 | - **Repository:** https://github.com/facebookresearch/dinov2
31 | - **Paper:** https://arxiv.org/abs/2304.07193
32 | - **Demo:** https://dinov2.metademolab.com/
33 |
34 | ## Uses
35 |
36 | The models are vision backbones providing multi-purpose features for downstream tasks.
37 |
38 | ### Direct Use
39 |
40 | The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
41 | - on depth estimation, semantic segmentation, using linear layers.
42 | - on image classification, using k-NN classifiers on the class token.
43 | - on image classification, with logistic regression classifiers applied on the class token.
44 | - on image classification, with a linear layer applied on the class token and the average of the patch tokens.
45 | - on image retrieval using nearest neighbors.
46 |
47 | ### Downstream Use
48 |
49 | It is technically possible to perform fine-tuning on the models, for small gains (we measured +2% on ImageNet-1k classification).
50 | We recommend keeping this as a very last step and only when necessary, as the features already provide good performance out-of-the-box.
51 |
52 | ## Bias, Risks, and Limitations
53 |
54 | Despite improvements thanks to the training method not using annotations, we still observe significant biases in our models toward rich households from Western countries.
55 |
56 | ### Recommendations
57 |
58 | We expect fine-tuning will increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels.
59 |
60 | ## How to Get Started with the Model
61 |
62 | Use the code below to get started with the model.
63 |
64 | ```python
65 | import torch
66 | dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
67 | dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
68 | dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
69 | dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
70 | ```
71 |
72 | ## Training Details
73 |
74 | ### Training Data
75 |
76 | - **Training data:** LVD-142M (see paper)
77 | - **Training regime:** fp16 using PyTorch-FSDP mixed-precision.
78 |
79 | ### Training Procedure
80 |
81 | - **Training objective:**
82 | - DINO self-distillation loss with multi-crop
83 | - iBOT masked-image modeling loss
84 | - KoLeo regularization on [CLS] tokens
85 | - **Architectures:**
86 | - ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN
87 | - ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN
88 | - ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN
89 | - ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN
90 | - **Distillation:**
91 | - Distillation follows the standard DINOv2 pretraining procedure, except the teacher is a pretrained ViT-g, frozen.
92 |
93 | ## Evaluation
94 |
95 | We refer users to the associated paper for the evaluation protocols.
96 |
97 |
98 |
99 | | model | ImageNet-1k<br>classif. (acc)<br>k-NN | ImageNet-1k<br>classif. (acc)<br>linear | ImageNet-1k<br>classif. V2 (acc)<br>linear | NYU-Depth v2<br>depth (RMSE)<br>linear 4 layers | SUN-RGBD<br>depth (RMSE)<br>NYU-D transfer | ADE20k<br>segm. (mAP)<br>multiscale | iNaturalist 2018<br>classif. (acc)<br>linear | Oxford-H<br>retrieval (mAP)<br>nearest neighbor |
100 | |:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
101 | | ViT-S/14 | 79.0% | 81.1% | 70.8% | 0.417 | 0.431 | 47.2 | 69.5% | 43.2 |
102 | | ViT-B/14 | 82.1% | 84.5% | 74.9% | 0.362 | 0.400 | 51.3 | 76.3% | 49.5 |
103 | | ViT-L/14 | 83.5% | 86.3% | 77.6% | 0.333 | 0.396 | 53.1 | 79.8% | 54.0 |
104 | | ViT-g/14 | 83.5% | 86.5% | 78.4% | 0.298 | 0.362 | 53.0 | 81.6% | 52.3 |
105 |
174 |
175 | ## Environmental Impact
176 |
177 | - **Hardware Type:** Nvidia A100
178 | - **Hours used:** 22,000 for ViT-g, 4,500 for ViT-S distillation, 5,300 for ViT-B distillation, 8,000 for ViT-L distillation
179 | - **Cloud Provider:** Private infra
180 | - **Compute Region:** USA
181 | - **Carbon Emitted:** 7t CO2eq
182 |
183 | #### Hardware
184 |
185 | Nvidia A100 GPUs
186 |
187 | #### Software
188 |
189 | PyTorch 2.0,
190 | xFormers 0.0.18
191 |
192 | **BibTeX**
193 |
194 | ```
195 | @misc{oquab2023dinov2,
196 | title={DINOv2: Learning Robust Visual Features without Supervision},
197 | author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
198 | journal={arXiv:2304.07193},
199 | year={2023}
200 | }
201 | ```
202 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/conda.yaml:
--------------------------------------------------------------------------------
1 | name: dinov2
2 | channels:
3 | - defaults
4 | - pytorch
5 | - nvidia
6 | - xformers
7 | - conda-forge
8 | dependencies:
9 | - python=3.9
10 | - pytorch::pytorch=2.0.0
11 | - pytorch::pytorch-cuda=11.7.0
12 | - pytorch::torchvision=0.15.0
13 | - omegaconf
14 | - torchmetrics=0.10.3
15 | - fvcore
16 | - iopath
17 | - xformers::xformers=0.0.18
18 | - pip
19 | - pip:
20 | - git+https://github.com/facebookincubator/submitit
21 | - --extra-index-url https://pypi.nvidia.com
22 | - cuml-cu11
23 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | __version__ = "0.0.1"
8 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pathlib
8 |
9 | from omegaconf import OmegaConf
10 |
11 |
12 | def load_config(config_name: str):
13 | config_filename = config_name + ".yaml"
14 | return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
15 |
16 |
17 | dinov2_default_config = load_config("ssl_default_config")
18 |
19 |
20 | def load_and_merge_config(config_name: str):
21 | default_config = OmegaConf.create(dinov2_default_config)
22 | loaded_config = load_config(config_name)
23 | return OmegaConf.merge(default_config, loaded_config)
24 |
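A minimal usage sketch (not part of the source file): merging one of the eval presets above on top of the default SSL config.

```python
from dinov2.configs import load_and_merge_config

# The name is a path relative to the dinov2/configs package, without ".yaml".
cfg = load_and_merge_config('eval/vitl14_pretrain')
print(cfg.student.arch, cfg.student.patch_size)  # vit_large 14
```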
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitb14_pretrain.yaml:
--------------------------------------------------------------------------------
1 | student:
2 | arch: vit_base
3 | patch_size: 14
4 | crops:
5 | global_crops_size: 518 # this is to set up the position embeddings properly
6 | local_crops_size: 98
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitg14_pretrain.yaml:
--------------------------------------------------------------------------------
1 | student:
2 | arch: vit_giant2
3 | patch_size: 14
4 | ffn_layer: swiglufused
5 | crops:
6 | global_crops_size: 518 # this is to set up the position embeddings properly
7 | local_crops_size: 98
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitl14_pretrain.yaml:
--------------------------------------------------------------------------------
1 | student:
2 | arch: vit_large
3 | patch_size: 14
4 | crops:
5 | global_crops_size: 518 # this is to set up the position embeddings properly
6 | local_crops_size: 98
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vits14_pretrain.yaml:
--------------------------------------------------------------------------------
1 | student:
2 | arch: vit_small
3 | patch_size: 14
4 | crops:
5 | global_crops_size: 518 # this is to set up the position embeddings properly
6 | local_crops_size: 98
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/ssl_default_config.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | WEIGHTS: ''
3 | compute_precision:
4 | grad_scaler: true
5 | teacher:
6 | backbone:
7 | sharding_strategy: SHARD_GRAD_OP
8 | mixed_precision:
9 | param_dtype: fp16
10 | reduce_dtype: fp16
11 | buffer_dtype: fp32
12 | dino_head:
13 | sharding_strategy: SHARD_GRAD_OP
14 | mixed_precision:
15 | param_dtype: fp16
16 | reduce_dtype: fp16
17 | buffer_dtype: fp32
18 | ibot_head:
19 | sharding_strategy: SHARD_GRAD_OP
20 | mixed_precision:
21 | param_dtype: fp16
22 | reduce_dtype: fp16
23 | buffer_dtype: fp32
24 | student:
25 | backbone:
26 | sharding_strategy: SHARD_GRAD_OP
27 | mixed_precision:
28 | param_dtype: fp16
29 | reduce_dtype: fp16
30 | buffer_dtype: fp32
31 | dino_head:
32 | sharding_strategy: SHARD_GRAD_OP
33 | mixed_precision:
34 | param_dtype: fp16
35 | reduce_dtype: fp32
36 | buffer_dtype: fp32
37 | ibot_head:
38 | sharding_strategy: SHARD_GRAD_OP
39 | mixed_precision:
40 | param_dtype: fp16
41 | reduce_dtype: fp32
42 | buffer_dtype: fp32
43 | dino:
44 | loss_weight: 1.0
45 | head_n_prototypes: 65536
46 | head_bottleneck_dim: 256
47 | head_nlayers: 3
48 | head_hidden_dim: 2048
49 | koleo_loss_weight: 0.1
50 | ibot:
51 | loss_weight: 1.0
52 | mask_sample_probability: 0.5
53 | mask_ratio_min_max:
54 | - 0.1
55 | - 0.5
56 | separate_head: false
57 | head_n_prototypes: 65536
58 | head_bottleneck_dim: 256
59 | head_nlayers: 3
60 | head_hidden_dim: 2048
61 | train:
62 | batch_size_per_gpu: 64
63 | dataset_path: ImageNet:split=TRAIN
64 | output_dir: .
65 | saveckp_freq: 20
66 | seed: 0
67 | num_workers: 10
68 | OFFICIAL_EPOCH_LENGTH: 1250
69 | cache_dataset: true
70 | centering: "centering" # or "sinkhorn_knopp"
71 | student:
72 | arch: vit_large
73 | patch_size: 16
74 | drop_path_rate: 0.3
75 | layerscale: 1.0e-05
76 | drop_path_uniform: true
77 | pretrained_weights: ''
78 | ffn_layer: "mlp"
79 | block_chunks: 0
80 | qkv_bias: true
81 | proj_bias: true
82 | ffn_bias: true
83 | teacher:
84 | momentum_teacher: 0.992
85 | final_momentum_teacher: 1
86 | warmup_teacher_temp: 0.04
87 | teacher_temp: 0.07
88 | warmup_teacher_temp_epochs: 30
89 | optim:
90 | epochs: 100
91 | weight_decay: 0.04
92 | weight_decay_end: 0.4
93 | base_lr: 0.004 # learning rate for a batch size of 1024
94 | lr: 0. # will be set after applying scaling rule
95 | warmup_epochs: 10
96 | min_lr: 1.0e-06
97 | clip_grad: 3.0
98 | freeze_last_layer_epochs: 1
99 | scaling_rule: sqrt_wrt_1024
100 | patch_embed_lr_mult: 0.2
101 | layerwise_decay: 0.9
102 | adamw_beta1: 0.9
103 | adamw_beta2: 0.999
104 | crops:
105 | global_crops_scale:
106 | - 0.32
107 | - 1.0
108 | local_crops_number: 8
109 | local_crops_scale:
110 | - 0.05
111 | - 0.32
112 | global_crops_size: 224
113 | local_crops_size: 96
114 | evaluation:
115 | eval_period_iterations: 12500
116 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitg14.yaml:
--------------------------------------------------------------------------------
1 | dino:
2 | head_n_prototypes: 131072
3 | head_bottleneck_dim: 384
4 | ibot:
5 | separate_head: true
6 | head_n_prototypes: 131072
7 | train:
8 | batch_size_per_gpu: 12
9 | dataset_path: ImageNet22k
10 | centering: sinkhorn_knopp
11 | student:
12 | arch: vit_giant2
13 | patch_size: 14
14 | drop_path_rate: 0.4
15 | ffn_layer: swiglufused
16 | block_chunks: 4
17 | teacher:
18 | momentum_teacher: 0.994
19 | optim:
20 | epochs: 500
21 | weight_decay_end: 0.2
22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024
23 | warmup_epochs: 80
24 | layerwise_decay: 1.0
25 | crops:
26 | local_crops_size: 98
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl14.yaml:
--------------------------------------------------------------------------------
1 | dino:
2 | head_n_prototypes: 131072
3 | head_bottleneck_dim: 384
4 | ibot:
5 | separate_head: true
6 | head_n_prototypes: 131072
7 | train:
8 | batch_size_per_gpu: 32
9 | dataset_path: ImageNet22k
10 | centering: sinkhorn_knopp
11 | student:
12 | arch: vit_large
13 | patch_size: 14
14 | drop_path_rate: 0.4
15 | ffn_layer: swiglufused
16 | block_chunks: 4
17 | teacher:
18 | momentum_teacher: 0.994
19 | optim:
20 | epochs: 500
21 | weight_decay_end: 0.2
22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024
23 | warmup_epochs: 80
24 | layerwise_decay: 1.0
25 | crops:
26 | local_crops_size: 98
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl16_short.yaml:
--------------------------------------------------------------------------------
1 | # this corresponds to the default config
2 | train:
3 | dataset_path: ImageNet:split=TRAIN
4 | batch_size_per_gpu: 64
5 | student:
6 | block_chunks: 4
7 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import random
9 | import re
10 | import socket
11 | from typing import Dict, List
12 |
13 | import torch
14 | import torch.distributed as dist
15 |
16 | _LOCAL_RANK = -1
17 | _LOCAL_WORLD_SIZE = -1
18 |
19 |
20 | def is_enabled() -> bool:
21 | """
22 | Returns:
23 | True if distributed training is enabled
24 | """
25 | return dist.is_available() and dist.is_initialized()
26 |
27 |
28 | def get_global_size() -> int:
29 | """
30 | Returns:
31 | The number of processes in the process group
32 | """
33 | return dist.get_world_size() if is_enabled() else 1
34 |
35 |
36 | def get_global_rank() -> int:
37 | """
38 | Returns:
39 | The rank of the current process within the global process group.
40 | """
41 | return dist.get_rank() if is_enabled() else 0
42 |
43 |
44 | def get_local_rank() -> int:
45 | """
46 | Returns:
47 | The rank of the current process within the local (per-machine) process group.
48 | """
49 | if not is_enabled():
50 | return 0
51 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
52 | return _LOCAL_RANK
53 |
54 |
55 | def get_local_size() -> int:
56 | """
57 | Returns:
58 | The size of the per-machine process group,
59 | i.e. the number of processes per machine.
60 | """
61 | if not is_enabled():
62 | return 1
63 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
64 | return _LOCAL_WORLD_SIZE
65 |
66 |
67 | def is_main_process() -> bool:
68 | """
69 | Returns:
70 | True if the current process is the main one.
71 | """
72 | return get_global_rank() == 0
73 |
74 |
75 | def _restrict_print_to_main_process() -> None:
76 | """
77 | This function disables printing when not in the main process
78 | """
79 | import builtins as __builtin__
80 |
81 | builtin_print = __builtin__.print
82 |
83 | def print(*args, **kwargs):
84 | force = kwargs.pop("force", False)
85 | if is_main_process() or force:
86 | builtin_print(*args, **kwargs)
87 |
88 | __builtin__.print = print
89 |
90 |
91 | def _get_master_port(seed: int = 0) -> int:
92 | MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
93 |
94 | master_port_str = os.environ.get("MASTER_PORT")
95 | if master_port_str is None:
96 | rng = random.Random(seed)
97 | return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
98 |
99 | return int(master_port_str)
100 |
101 |
102 | def _get_available_port() -> int:
103 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
104 | # A "" host address means INADDR_ANY i.e. binding to all interfaces.
105 | # Note this is not compatible with IPv6.
106 | s.bind(("", 0))
107 | port = s.getsockname()[1]
108 | return port
109 |
110 |
111 | _TORCH_DISTRIBUTED_ENV_VARS = (
112 | "MASTER_ADDR",
113 | "MASTER_PORT",
114 | "RANK",
115 | "WORLD_SIZE",
116 | "LOCAL_RANK",
117 | "LOCAL_WORLD_SIZE",
118 | )
119 |
120 |
121 | def _collect_env_vars() -> Dict[str, str]:
122 | return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
123 |
124 |
125 | def _is_slurm_job_process() -> bool:
126 | return "SLURM_JOB_ID" in os.environ
127 |
128 |
129 | def _parse_slurm_node_list(s: str) -> List[str]:
130 | nodes = []
131 | # Extract "hostname", "hostname[1-2,3,4-5]," substrings
132 | p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
133 | for m in p.finditer(s):
134 | prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
135 | for suffix in suffixes.split(","):
136 | span = suffix.split("-")
137 | if len(span) == 1:
138 | nodes.append(prefix + suffix)
139 | else:
140 | width = len(span[0])
141 | start, end = int(span[0]), int(span[1]) + 1
142 | nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
143 | return nodes
144 |
145 |
146 | def _check_env_variable(key: str, new_value: str):
147 | # Only check for difference with preset environment variables
148 | if key in os.environ and os.environ[key] != new_value:
149 | raise RuntimeError(f"Cannot export environment variables as {key} is already set")
150 |
151 |
152 | class _TorchDistributedEnvironment:
153 | def __init__(self):
154 | self.master_addr = "127.0.0.1"
155 | self.master_port = 0
156 | self.rank = -1
157 | self.world_size = -1
158 | self.local_rank = -1
159 | self.local_world_size = -1
160 |
161 | if _is_slurm_job_process():
162 | return self._set_from_slurm_env()
163 |
164 | env_vars = _collect_env_vars()
165 | if not env_vars:
166 | # Environment is not set
167 | pass
168 | elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
169 | # Environment is fully set
170 | return self._set_from_preset_env()
171 | else:
172 | # Environment is partially set
173 | collected_env_vars = ", ".join(env_vars.keys())
174 | raise RuntimeError(f"Partially set environment: {collected_env_vars}")
175 |
176 | if torch.cuda.device_count() > 0:
177 | return self._set_from_local()
178 |
179 | raise RuntimeError("Can't initialize PyTorch distributed environment")
180 |
181 | # Slurm job created with sbatch, submitit, etc...
182 | def _set_from_slurm_env(self):
183 | # logger.info("Initialization from Slurm environment")
184 | job_id = int(os.environ["SLURM_JOB_ID"])
185 | node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
186 | nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
187 | assert len(nodes) == node_count
188 |
189 | self.master_addr = nodes[0]
190 | self.master_port = _get_master_port(seed=job_id)
191 | self.rank = int(os.environ["SLURM_PROCID"])
192 | self.world_size = int(os.environ["SLURM_NTASKS"])
193 | assert self.rank < self.world_size
194 | self.local_rank = int(os.environ["SLURM_LOCALID"])
195 | self.local_world_size = self.world_size // node_count
196 | assert self.local_rank < self.local_world_size
197 |
198 | # Single node job with preset environment (i.e. torchrun)
199 | def _set_from_preset_env(self):
200 | # logger.info("Initialization from preset environment")
201 | self.master_addr = os.environ["MASTER_ADDR"]
202 | self.master_port = os.environ["MASTER_PORT"]
203 | self.rank = int(os.environ["RANK"])
204 | self.world_size = int(os.environ["WORLD_SIZE"])
205 | assert self.rank < self.world_size
206 | self.local_rank = int(os.environ["LOCAL_RANK"])
207 | self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
208 | assert self.local_rank < self.local_world_size
209 |
210 | # Single node and GPU job (i.e. local script run)
211 | def _set_from_local(self):
212 | # logger.info("Initialization from local")
213 | self.master_addr = "127.0.0.1"
214 | self.master_port = _get_available_port()
215 | self.rank = 0
216 | self.world_size = 1
217 | self.local_rank = 0
218 | self.local_world_size = 1
219 |
220 | def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
221 | # See the "Environment variable initialization" section from
222 | # https://pytorch.org/docs/stable/distributed.html for the complete list of
223 | # environment variables required for the env:// initialization method.
224 | env_vars = {
225 | "MASTER_ADDR": self.master_addr,
226 | "MASTER_PORT": str(self.master_port),
227 | "RANK": str(self.rank),
228 | "WORLD_SIZE": str(self.world_size),
229 | "LOCAL_RANK": str(self.local_rank),
230 | "LOCAL_WORLD_SIZE": str(self.local_world_size),
231 | }
232 | if not overwrite:
233 | for k, v in env_vars.items():
234 | _check_env_variable(k, v)
235 |
236 | os.environ.update(env_vars)
237 | return self
238 |
239 |
240 | def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
241 | """Enable distributed mode
242 |
243 | Args:
244 | set_cuda_current_device: If True, call torch.cuda.set_device() to set the
245 | current PyTorch CUDA device to the one matching the local rank.
246 | overwrite: If True, overwrites already set variables. Else fails.
247 | """
248 |
249 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE
250 | if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
251 | raise RuntimeError("Distributed mode has already been enabled")
252 | torch_env = _TorchDistributedEnvironment()
253 | torch_env.export(overwrite=overwrite)
254 |
255 | if set_cuda_current_device:
256 | torch.cuda.set_device(torch_env.local_rank)
257 |
258 | if allow_nccl_timeout:
259 | # This allows to use torch distributed timeout in a NCCL backend
260 | key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
261 | if not overwrite:
262 | _check_env_variable(key, value)
263 | os.environ[key] = value
264 |
265 | dist.init_process_group(backend="nccl")
266 | dist.barrier()
267 |
268 | # Finalize setup
269 | _LOCAL_RANK = torch_env.local_rank
270 | _LOCAL_WORLD_SIZE = torch_env.local_world_size
271 | _restrict_print_to_main_process()
272 |
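A minimal usage sketch (not part of the source file; it assumes a CUDA machine with the usual torchrun/Slurm environment variables, since `enable()` initializes an NCCL process group):

```python
import dinov2.distributed as distributed

distributed.enable(overwrite=True)  # export env vars and call init_process_group
if distributed.is_main_process():
    print('world size:', distributed.get_global_size(),
          'local rank:', distributed.get_local_rank())
```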
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/eval/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/eval/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from enum import Enum
8 | import logging
9 | from typing import Any, Dict, Optional
10 |
11 | import torch
12 | from torch import Tensor
13 | from torchmetrics import Metric, MetricCollection
14 | from torchmetrics.classification import MulticlassAccuracy
15 | from torchmetrics.utilities.data import dim_zero_cat, select_topk
16 |
17 |
18 | logger = logging.getLogger("dinov2")
19 |
20 |
21 | class MetricType(Enum):
22 | MEAN_ACCURACY = "mean_accuracy"
23 | MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
24 | PER_CLASS_ACCURACY = "per_class_accuracy"
25 | IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"
26 |
27 | @property
28 | def accuracy_averaging(self):
29 | return getattr(AccuracyAveraging, self.name, None)
30 |
31 | def __str__(self):
32 | return self.value
33 |
34 |
35 | class AccuracyAveraging(Enum):
36 | MEAN_ACCURACY = "micro"
37 | MEAN_PER_CLASS_ACCURACY = "macro"
38 | PER_CLASS_ACCURACY = "none"
39 |
40 | def __str__(self):
41 | return self.value
42 |
43 |
44 | def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
45 | if metric_type.accuracy_averaging is not None:
46 | return build_topk_accuracy_metric(
47 | average_type=metric_type.accuracy_averaging,
48 | num_classes=num_classes,
49 | ks=(1, 5) if ks is None else ks,
50 | )
51 | elif metric_type == MetricType.IMAGENET_REAL_ACCURACY:
52 | return build_topk_imagenet_real_accuracy_metric(
53 | num_classes=num_classes,
54 | ks=(1, 5) if ks is None else ks,
55 | )
56 |
57 | raise ValueError(f"Unknown metric type {metric_type}")
58 |
59 |
60 | def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
61 | metrics: Dict[str, Metric] = {
62 | f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks
63 | }
64 | return MetricCollection(metrics)
65 |
66 |
67 | def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
68 | metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks}
69 | return MetricCollection(metrics)
70 |
71 |
72 | class ImageNetReaLAccuracy(Metric):
73 | is_differentiable: bool = False
74 | higher_is_better: Optional[bool] = None
75 | full_state_update: bool = False
76 |
77 | def __init__(
78 | self,
79 | num_classes: int,
80 | top_k: int = 1,
81 | **kwargs: Any,
82 | ) -> None:
83 | super().__init__(**kwargs)
84 | self.num_classes = num_classes
85 | self.top_k = top_k
86 | self.add_state("tp", [], dist_reduce_fx="cat")
87 |
88 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
89 | # preds [B, D]
90 | # target [B, A]
91 | # preds_oh [B, D] with 0 and 1
92 | # select top K highest probabilities, use one hot representation
93 | preds_oh = select_topk(preds, self.top_k)
94 | # target_oh [B, D + 1] with 0 and 1
95 | target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
96 | target = target.long()
97 | # for undefined targets (-1) use a fake value `num_classes`
98 | target[target == -1] = self.num_classes
99 | # fill targets, use one hot representation
100 | target_oh.scatter_(1, target, 1)
101 | # target_oh [B, D] (remove the fake target at index `num_classes`)
102 | target_oh = target_oh[:, :-1]
103 | # tp [B] with 0 and 1
104 | tp = (preds_oh * target_oh == 1).sum(dim=1)
105 | # at least one match between prediction and target
106 | tp.clip_(max=1)
107 | # ignore instances where no targets are defined
108 | mask = target_oh.sum(dim=1) > 0
109 | tp = tp[mask]
110 | self.tp.append(tp) # type: ignore
111 |
112 | def compute(self) -> Tensor:
113 | tp = dim_zero_cat(self.tp) # type: ignore
114 | return tp.float().mean()
115 |
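A minimal usage sketch for `build_metric` above (not part of the source file; the batch of predictions is random):

```python
import torch
from dinov2.eval.metrics import MetricType, build_metric

# Top-1 / top-5 micro-averaged accuracy over 1000 classes.
metric = build_metric(MetricType.MEAN_ACCURACY, num_classes=1000)

preds = torch.randn(8, 1000).softmax(dim=-1)
target = torch.randint(0, 1000, (8,))
metric.update(preds=preds, target=target)
print(metric.compute())  # e.g. {'top-1': tensor(...), 'top-5': tensor(...)}
```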
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/eval/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | from typing import Any, List, Optional, Tuple
9 |
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 |
13 | from dinov2.models import build_model_from_cfg
14 | from dinov2.utils.config import setup
15 | import dinov2.utils.utils as dinov2_utils
16 |
17 |
18 | def get_args_parser(
19 | description: Optional[str] = None,
20 | parents: Optional[List[argparse.ArgumentParser]] = None,
21 | add_help: bool = True,
22 | ):
23 | parser = argparse.ArgumentParser(
24 | description=description,
25 | parents=parents or [],
26 | add_help=add_help,
27 | )
28 | parser.add_argument(
29 | "--config-file",
30 | type=str,
31 | help="Model configuration file",
32 | )
33 | parser.add_argument(
34 | "--pretrained-weights",
35 | type=str,
36 | help="Pretrained model weights",
37 | )
38 | parser.add_argument(
39 | "--output-dir",
40 | default="",
41 | type=str,
42 | help="Output directory to write results and logs",
43 | )
44 | parser.add_argument(
45 | "--opts",
46 | help="Extra configuration options",
47 | default=[],
48 | nargs="+",
49 | )
50 | return parser
51 |
52 |
53 | def get_autocast_dtype(config):
54 | teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype
55 | if teacher_dtype_str == "fp16":
56 | return torch.half
57 | elif teacher_dtype_str == "bf16":
58 | return torch.bfloat16
59 | else:
60 | return torch.float
61 |
62 |
63 | def build_model_for_eval(config, pretrained_weights):
64 | model, _ = build_model_from_cfg(config, only_teacher=True)
65 | dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher")
66 | model.eval()
67 | model.cuda()
68 | return model
69 |
70 |
71 | def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
72 | cudnn.benchmark = True
73 | config = setup(args)
74 | model = build_model_for_eval(config, args.pretrained_weights)
75 | autocast_dtype = get_autocast_dtype(config)
76 | return model, autocast_dtype
77 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/eval/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from typing import Dict, Optional
9 |
10 | import torch
11 | from torch import nn
12 | from torchmetrics import MetricCollection
13 |
14 | from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
15 | import dinov2.distributed as distributed
16 | from dinov2.logging import MetricLogger
17 |
18 |
19 | logger = logging.getLogger("dinov2")
20 |
21 |
22 | class ModelWithNormalize(torch.nn.Module):
23 | def __init__(self, model):
24 | super().__init__()
25 | self.model = model
26 |
27 | def forward(self, samples):
28 | return nn.functional.normalize(self.model(samples), dim=1, p=2)
29 |
30 |
31 | class ModelWithIntermediateLayers(nn.Module):
32 | def __init__(self, feature_model, n_last_blocks, autocast_ctx):
33 | super().__init__()
34 | self.feature_model = feature_model
35 | self.feature_model.eval()
36 | self.n_last_blocks = n_last_blocks
37 | self.autocast_ctx = autocast_ctx
38 |
39 | def forward(self, images):
40 | with torch.inference_mode():
41 | with self.autocast_ctx():
42 | features = self.feature_model.get_intermediate_layers(
43 | images, self.n_last_blocks, return_class_token=True
44 | )
45 | return features
46 |
47 |
48 | @torch.inference_mode()
49 | def evaluate(
50 | model: nn.Module,
51 | data_loader,
52 | postprocessors: Dict[str, nn.Module],
53 | metrics: Dict[str, MetricCollection],
54 | device: torch.device,
55 | criterion: Optional[nn.Module] = None,
56 | ):
57 | model.eval()
58 | if criterion is not None:
59 | criterion.eval()
60 |
61 | for metric in metrics.values():
62 | metric = metric.to(device)
63 |
64 | metric_logger = MetricLogger(delimiter=" ")
65 | header = "Test:"
66 |
67 | for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header):
68 | outputs = model(samples.to(device))
69 | targets = targets.to(device)
70 |
71 | if criterion is not None:
72 | loss = criterion(outputs, targets)
73 | metric_logger.update(loss=loss.item())
74 |
75 | for k, metric in metrics.items():
76 | metric_inputs = postprocessors[k](outputs, targets)
77 | metric.update(**metric_inputs)
78 |
79 | metric_logger.synchronize_between_processes()
80 | logger.info(f"Averaged stats: {metric_logger}")
81 |
82 | stats = {k: metric.compute() for k, metric in metrics.items()}
83 | metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
84 | return metric_logger_stats, stats
85 |
86 |
87 | def all_gather_and_flatten(tensor_rank):
88 | tensor_all_ranks = torch.empty(
89 | distributed.get_global_size(),
90 | *tensor_rank.shape,
91 | dtype=tensor_rank.dtype,
92 | device=tensor_rank.device,
93 | )
94 | tensor_list = list(tensor_all_ranks.unbind(0))
95 | torch.distributed.all_gather(tensor_list, tensor_rank.contiguous())
96 | return tensor_all_ranks.flatten(end_dim=1)
97 |
98 |
99 | def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False):
100 | dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset)
101 | sample_count = len(dataset_with_enumerated_targets)
102 | data_loader = make_data_loader(
103 | dataset=dataset_with_enumerated_targets,
104 | batch_size=batch_size,
105 | num_workers=num_workers,
106 | sampler_type=SamplerType.DISTRIBUTED,
107 | drop_last=False,
108 | shuffle=False,
109 | )
110 | return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu)
111 |
112 |
113 | @torch.inference_mode()
114 | def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False):
115 | gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
116 | metric_logger = MetricLogger(delimiter=" ")
117 | features, all_labels = None, None
118 | for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
119 | samples = samples.cuda(non_blocking=True)
120 | labels_rank = labels_rank.cuda(non_blocking=True)
121 | index = index.cuda(non_blocking=True)
122 | features_rank = model(samples).float()
123 |
124 | # init storage feature matrix
125 | if features is None:
126 | features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
127 | labels_shape = list(labels_rank.shape)
128 | labels_shape[0] = sample_count
129 | all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device)
130 | logger.info(f"Storing features into tensor of shape {features.shape}")
131 |
132 | # share indexes, features and labels between processes
133 | index_all = all_gather_and_flatten(index).to(gather_device)
134 | features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
135 | labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)
136 |
137 | # update storage feature matrix
138 | if len(index_all) > 0:
139 | features.index_copy_(0, index_all, features_all_ranks)
140 | all_labels.index_copy_(0, index_all, labels_all_ranks)
141 |
142 | logger.info(f"Features shape: {tuple(features.shape)}")
143 | logger.info(f"Labels shape: {tuple(all_labels.shape)}")
144 |
145 | assert torch.all(all_labels > -1)
146 |
147 | return features, all_labels
148 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/fsdp/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | from typing import Any
9 |
10 | import torch
11 | import dinov2.distributed as distributed
12 | from functools import partial
13 | from fvcore.common.checkpoint import Checkpointer
14 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
15 | from torch.distributed.fsdp import ShardingStrategy
16 | from torch.distributed.fsdp import MixedPrecision
17 | from torch.distributed.fsdp import StateDictType
18 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
19 | from torch.distributed.fsdp.wrap import ModuleWrapPolicy
20 | from torch.distributed.fsdp._runtime_utils import _reshard
21 |
22 |
23 | def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
24 | sharding_strategy_dict = {
25 | "NO_SHARD": ShardingStrategy.NO_SHARD,
26 | "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
27 | "FULL_SHARD": ShardingStrategy.FULL_SHARD,
28 | }
29 |
30 | dtype_dict = {
31 | "fp32": torch.float32,
32 | "fp16": torch.float16,
33 | "bf16": torch.bfloat16,
34 | }
35 |
36 | mixed_precision_config = MixedPrecision(
37 | param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype],
38 | reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype],
39 | buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype],
40 | )
41 |
42 | sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]
43 |
44 | local_rank = distributed.get_local_rank()
45 |
46 | fsdp_wrapper = partial(
47 | FSDP,
48 | sharding_strategy=sharding_strategy_config,
49 | mixed_precision=mixed_precision_config,
50 | device_id=local_rank,
51 | sync_module_states=True,
52 | use_orig_params=True,
53 | auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
54 | )
55 | return fsdp_wrapper
56 |
57 |
58 | def is_fsdp(x):
59 | return isinstance(x, FSDP)
60 |
61 |
62 | def is_sharded_fsdp(x):
63 | return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD
64 |
65 |
66 | def free_if_fsdp(x):
67 | if is_sharded_fsdp(x):
68 | handles = x._handles
69 | true_list = [True for h in handles]
70 | _reshard(x, handles, true_list)
71 |
72 |
73 | def get_fsdp_modules(x):
74 | return FSDP.fsdp_modules(x)
75 |
76 |
77 | def reshard_fsdp_model(x):
78 | for m in get_fsdp_modules(x):
79 | free_if_fsdp(m)
80 |
81 |
82 | def rankstr():
83 | return f"rank_{distributed.get_global_rank()}"
84 |
85 |
86 | class FSDPCheckpointer(Checkpointer):
87 | def save(self, name: str, **kwargs: Any) -> None:
88 | """
89 | Dump model and checkpointables to a file.
90 |
91 | Args:
92 | name (str): name of the file.
93 | kwargs (dict): extra arbitrary data to save.
94 | """
95 | if not self.save_dir or not self.save_to_disk:
96 | return
97 |
98 | data = {}
99 | with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
100 | data["model"] = self.model.state_dict()
101 |
102 | # data["model"] = self.model.state_dict()
103 | for key, obj in self.checkpointables.items():
104 | data[key] = obj.state_dict()
105 | data.update(kwargs)
106 |
107 | basename = f"{name}.{rankstr()}.pth"
108 | save_file = os.path.join(self.save_dir, basename)
109 | assert os.path.basename(save_file) == basename, basename
110 | self.logger.info("Saving checkpoint to {}".format(save_file))
111 | with self.path_manager.open(save_file, "wb") as f:
112 | torch.save(data, f)
113 | self.tag_last_checkpoint(basename)
114 |
115 | def load(self, *args, **kwargs):
116 | with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
117 | return super().load(*args, **kwargs)
118 |
119 | def has_checkpoint(self) -> bool:
120 | """
121 | Returns:
122 | bool: whether a checkpoint exists in the target directory.
123 | """
124 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
125 | return self.path_manager.exists(save_file)
126 |
127 | def get_checkpoint_file(self) -> str:
128 | """
129 | Returns:
130 | str: The latest checkpoint file in target directory.
131 | """
132 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
133 | try:
134 | with self.path_manager.open(save_file, "r") as f:
135 | last_saved = f.read().strip()
136 | except IOError:
137 | # if file doesn't exist, maybe because it has just been
138 | # deleted by a separate process
139 | return ""
140 | # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got
141 | # `Union[bytes, str]`.
142 | return os.path.join(self.save_dir, last_saved)
143 |
144 | def tag_last_checkpoint(self, last_filename_basename: str) -> None:
145 | """
146 | Tag the last checkpoint.
147 |
148 | Args:
149 | last_filename_basename (str): the basename of the last filename.
150 | """
151 | if distributed.is_enabled():
152 | torch.distributed.barrier()
153 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
154 | with self.path_manager.open(save_file, "w") as f:
155 | f.write(last_filename_basename) # pyre-ignore
156 |
157 |
158 | ShardedGradScaler = ShardedGradScaler
159 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dino_head import DINOHead
8 | from .mlp import Mlp
9 | from .patch_embed import PatchEmbed
10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11 | from .block import NestedTensorBlock
12 | from .attention import MemEffAttention
13 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
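# Usage sketch (illustrative; not part of the upstream file): a quick shape
# check of the plain Attention path, which MemEffAttention also falls back to
# when xFormers is unavailable. All sizes are arbitrary example values.
if __name__ == "__main__":
    import torch

    attn = Attention(dim=384, num_heads=6, qkv_bias=True)
    tokens = torch.randn(2, 197, 384)       # (batch, tokens, channels)
    out = attn(tokens)
    assert out.shape == (2, 197, 384)       # attention preserves the token shape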
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 |             x = x + self.drop_path2(ffn_residual_func(x))
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 |     This performs the index select, concatenates the tensors, and provides the attn_bias from the cache.
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 |     if all_shapes not in attn_bias_cache:
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 |                 scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
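# Usage sketch (illustrative; not part of the upstream file): one Block applied
# to a dummy token sequence. Sizes are arbitrary example values; because of the
# relative imports above, run it as a module (e.g. `python -m dinov2.layers.block`
# with the package on PYTHONPATH).
if __name__ == "__main__":
    import torch

    block = Block(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True, init_values=1e-5)
    block.eval()                             # take the plain residual path
    tokens = torch.randn(2, 197, 384)        # (batch, tokens, channels)
    assert block(tokens).shape == tokens.shape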
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.init import trunc_normal_
10 | from torch.nn.utils import weight_norm
11 |
12 |
13 | class DINOHead(nn.Module):
14 | def __init__(
15 | self,
16 | in_dim,
17 | out_dim,
18 | use_bn=False,
19 | nlayers=3,
20 | hidden_dim=2048,
21 | bottleneck_dim=256,
22 | mlp_bias=True,
23 | ):
24 | super().__init__()
25 | nlayers = max(nlayers, 1)
26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27 | self.apply(self._init_weights)
28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29 | self.last_layer.weight_g.data.fill_(1)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=0.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41 | x = self.last_layer(x)
42 | return x
43 |
44 |
45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46 | if nlayers == 1:
47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48 | else:
49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50 | if use_bn:
51 | layers.append(nn.BatchNorm1d(hidden_dim))
52 | layers.append(nn.GELU())
53 | for _ in range(nlayers - 2):
54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55 | if use_bn:
56 | layers.append(nn.BatchNorm1d(hidden_dim))
57 | layers.append(nn.GELU())
58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59 | return nn.Sequential(*layers)
60 |
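# Usage sketch (illustrative; not part of the upstream file): the head maps
# backbone embeddings through a bottleneck MLP, L2-normalizes them, and projects
# to prototype scores with a weight-normalized linear layer. Dimensions are
# arbitrary example values (the out_dim used for training is much larger).
if __name__ == "__main__":
    head = DINOHead(in_dim=384, out_dim=1024, hidden_dim=512, bottleneck_dim=128)
    feats = torch.randn(8, 384)              # (batch, backbone embedding dim)
    assert head(feats).shape == (8, 1024)    # one score per prototype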
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
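# Usage sketch (illustrative; not part of the upstream file): with drop_prob=0.5
# in training mode, roughly half of the samples are zeroed and the survivors are
# rescaled by 1/keep_prob, so the expected activation is unchanged.
if __name__ == "__main__":
    import torch

    x = torch.ones(1000, 4, 8)
    y = drop_path(x, drop_prob=0.5, training=True)
    print((y.sum(dim=(1, 2)) > 0).float().mean())    # ~0.5: fraction of samples kept
    print(y.mean())                                   # ~1.0: mean preserved in expectation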
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 |         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
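# Usage sketch (illustrative; not part of the upstream file): a 224x224 image
# with 14x14 patches yields a 16x16 grid, i.e. 256 patch tokens of width
# embed_dim. All sizes are arbitrary example values.
if __name__ == "__main__":
    import torch

    embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
    image = torch.randn(1, 3, 224, 224)      # (B, C, H, W)
    tokens = embed(image)
    assert tokens.shape == (1, 256, 384)     # (B, (H/14) * (W/14), embed_dim)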
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
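# Usage sketch (illustrative; not part of the upstream file): SwiGLUFFNFused
# shrinks the requested hidden width to roughly 2/3 and rounds it up to a
# multiple of 8 (a request of 4 * 384 = 1536 becomes 1024). Sizes are arbitrary
# example values.
if __name__ == "__main__":
    import torch

    ffn = SwiGLUFFNFused(in_features=384, hidden_features=4 * 384)
    x = torch.randn(2, 197, 384)
    assert ffn(x).shape == x.shape           # the FFN preserves the token shape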
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/logging/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import functools
8 | import logging
9 | import os
10 | import sys
11 | from typing import Optional
12 |
13 | import dinov2.distributed as distributed
14 | from .helpers import MetricLogger, SmoothedValue
15 |
16 |
17 | # So that calling _configure_logger multiple times won't add many handlers
18 | @functools.lru_cache()
19 | def _configure_logger(
20 | name: Optional[str] = None,
21 | *,
22 | level: int = logging.DEBUG,
23 | output: Optional[str] = None,
24 | ):
25 | """
26 | Configure a logger.
27 |
28 | Adapted from Detectron2.
29 |
30 | Args:
31 | name: The name of the logger to configure.
32 | level: The logging level to use.
33 | output: A file name or a directory to save log. If None, will not save log file.
34 | If ends with ".txt" or ".log", assumed to be a file name.
35 | Otherwise, logs will be saved to `output/log.txt`.
36 |
37 | Returns:
38 | The configured logger.
39 | """
40 |
41 | logger = logging.getLogger(name)
42 | logger.setLevel(level)
43 | logger.propagate = False
44 |
45 | # Loosely match Google glog format:
46 | # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
47 | # but use a shorter timestamp and include the logger name:
48 | # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg
49 | fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
50 | fmt_message = "%(message)s"
51 | fmt = fmt_prefix + fmt_message
52 | datefmt = "%Y%m%d %H:%M:%S"
53 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
54 |
55 | # stdout logging for main worker only
56 | if distributed.is_main_process():
57 | handler = logging.StreamHandler(stream=sys.stdout)
58 | handler.setLevel(logging.DEBUG)
59 | handler.setFormatter(formatter)
60 | logger.addHandler(handler)
61 |
62 | # file logging for all workers
63 | if output:
64 | if os.path.splitext(output)[-1] in (".txt", ".log"):
65 | filename = output
66 | else:
67 | filename = os.path.join(output, "logs", "log.txt")
68 |
69 | if not distributed.is_main_process():
70 | global_rank = distributed.get_global_rank()
71 | filename = filename + ".rank{}".format(global_rank)
72 |
73 | os.makedirs(os.path.dirname(filename), exist_ok=True)
74 |
75 | handler = logging.StreamHandler(open(filename, "a"))
76 | handler.setLevel(logging.DEBUG)
77 | handler.setFormatter(formatter)
78 | logger.addHandler(handler)
79 |
80 | return logger
81 |
82 |
83 | def setup_logging(
84 | output: Optional[str] = None,
85 | *,
86 | name: Optional[str] = None,
87 | level: int = logging.DEBUG,
88 | capture_warnings: bool = True,
89 | ) -> None:
90 | """
91 | Setup logging.
92 |
93 | Args:
94 | output: A file name or a directory to save log files. If None, log
95 | files will not be saved. If output ends with ".txt" or ".log", it
96 | is assumed to be a file name.
97 | Otherwise, logs will be saved to `output/log.txt`.
98 | name: The name of the logger to configure, by default the root logger.
99 | level: The logging level to use.
100 | capture_warnings: Whether warnings should be captured as logs.
101 | """
102 | logging.captureWarnings(capture_warnings)
103 | _configure_logger(name, level=level, output=output)
104 |
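# Usage sketch (illustrative; not part of the upstream file): the main process
# logs to stdout; when an output path is given, every worker also logs to a
# file, with a ".rank{N}" suffix on non-main ranks.
#
#     import logging
#     from dinov2.logging import setup_logging
#
#     setup_logging(output=None, name="dinov2", level=logging.INFO)
#     logging.getLogger("dinov2").info("logging configured")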
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/logging/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from collections import defaultdict, deque
8 | import datetime
9 | import json
10 | import logging
11 | import time
12 |
13 | import torch
14 |
15 | import dinov2.distributed as distributed
16 |
17 |
18 | logger = logging.getLogger("dinov2")
19 |
20 |
21 | class MetricLogger(object):
22 | def __init__(self, delimiter="\t", output_file=None):
23 | self.meters = defaultdict(SmoothedValue)
24 | self.delimiter = delimiter
25 | self.output_file = output_file
26 |
27 | def update(self, **kwargs):
28 | for k, v in kwargs.items():
29 | if isinstance(v, torch.Tensor):
30 | v = v.item()
31 | assert isinstance(v, (float, int))
32 | self.meters[k].update(v)
33 |
34 | def __getattr__(self, attr):
35 | if attr in self.meters:
36 | return self.meters[attr]
37 | if attr in self.__dict__:
38 | return self.__dict__[attr]
39 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
40 |
41 | def __str__(self):
42 | loss_str = []
43 | for name, meter in self.meters.items():
44 | loss_str.append("{}: {}".format(name, str(meter)))
45 | return self.delimiter.join(loss_str)
46 |
47 | def synchronize_between_processes(self):
48 | for meter in self.meters.values():
49 | meter.synchronize_between_processes()
50 |
51 | def add_meter(self, name, meter):
52 | self.meters[name] = meter
53 |
54 | def dump_in_output_file(self, iteration, iter_time, data_time):
55 | if self.output_file is None or not distributed.is_main_process():
56 | return
57 | dict_to_dump = dict(
58 | iteration=iteration,
59 | iter_time=iter_time,
60 | data_time=data_time,
61 | )
62 | dict_to_dump.update({k: v.median for k, v in self.meters.items()})
63 | with open(self.output_file, "a") as f:
64 | f.write(json.dumps(dict_to_dump) + "\n")
65 |
66 |
67 | def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0):
68 | i = start_iteration
69 | if not header:
70 | header = ""
71 | start_time = time.time()
72 | end = time.time()
73 | iter_time = SmoothedValue(fmt="{avg:.6f}")
74 | data_time = SmoothedValue(fmt="{avg:.6f}")
75 |
76 | if n_iterations is None:
77 | n_iterations = len(iterable)
78 |
79 | space_fmt = ":" + str(len(str(n_iterations))) + "d"
80 |
81 | log_list = [
82 | header,
83 | "[{0" + space_fmt + "}/{1}]",
84 | "eta: {eta}",
85 | "{meters}",
86 | "time: {time}",
87 | "data: {data}",
88 | ]
89 | if torch.cuda.is_available():
90 | log_list += ["max mem: {memory:.0f}"]
91 |
92 | log_msg = self.delimiter.join(log_list)
93 | MB = 1024.0 * 1024.0
94 | for obj in iterable:
95 | data_time.update(time.time() - end)
96 | yield obj
97 | iter_time.update(time.time() - end)
98 | if i % print_freq == 0 or i == n_iterations - 1:
99 | self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
100 | eta_seconds = iter_time.global_avg * (n_iterations - i)
101 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
102 | if torch.cuda.is_available():
103 | logger.info(
104 | log_msg.format(
105 | i,
106 | n_iterations,
107 | eta=eta_string,
108 | meters=str(self),
109 | time=str(iter_time),
110 | data=str(data_time),
111 | memory=torch.cuda.max_memory_allocated() / MB,
112 | )
113 | )
114 | else:
115 | logger.info(
116 | log_msg.format(
117 | i,
118 | n_iterations,
119 | eta=eta_string,
120 | meters=str(self),
121 | time=str(iter_time),
122 | data=str(data_time),
123 | )
124 | )
125 | i += 1
126 | end = time.time()
127 | if i >= n_iterations:
128 | break
129 | total_time = time.time() - start_time
130 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
131 | logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations))
132 |
133 |
134 | class SmoothedValue:
135 | """Track a series of values and provide access to smoothed values over a
136 | window or the global series average.
137 | """
138 |
139 | def __init__(self, window_size=20, fmt=None):
140 | if fmt is None:
141 | fmt = "{median:.4f} ({global_avg:.4f})"
142 | self.deque = deque(maxlen=window_size)
143 | self.total = 0.0
144 | self.count = 0
145 | self.fmt = fmt
146 |
147 | def update(self, value, num=1):
148 | self.deque.append(value)
149 | self.count += num
150 | self.total += value * num
151 |
152 | def synchronize_between_processes(self):
153 | """
154 | Distributed synchronization of the metric
155 | Warning: does not synchronize the deque!
156 | """
157 | if not distributed.is_enabled():
158 | return
159 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
160 | torch.distributed.barrier()
161 | torch.distributed.all_reduce(t)
162 | t = t.tolist()
163 | self.count = int(t[0])
164 | self.total = t[1]
165 |
166 | @property
167 | def median(self):
168 | d = torch.tensor(list(self.deque))
169 | return d.median().item()
170 |
171 | @property
172 | def avg(self):
173 | d = torch.tensor(list(self.deque), dtype=torch.float32)
174 | return d.mean().item()
175 |
176 | @property
177 | def global_avg(self):
178 | return self.total / self.count
179 |
180 | @property
181 | def max(self):
182 | return max(self.deque)
183 |
184 | @property
185 | def value(self):
186 | return self.deque[-1]
187 |
188 | def __str__(self):
189 | return self.fmt.format(
190 | median=self.median,
191 | avg=self.avg,
192 | global_avg=self.global_avg,
193 | max=self.max,
194 | value=self.value,
195 | )
196 |
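# Usage sketch (illustrative; not part of the upstream file): SmoothedValue
# reports window statistics next to a running global average. Run as a module
# (e.g. `python -m dinov2.logging.helpers`) since this file imports the dinov2
# package at the top.
if __name__ == "__main__":
    v = SmoothedValue(window_size=3, fmt="{median:.1f} ({global_avg:.1f})")
    for value in [1.0, 2.0, 3.0, 10.0]:
        v.update(value)
    print(str(v))    # window holds [2, 3, 10] -> "3.0 (4.0)"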
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/loss/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dino_clstoken_loss import DINOLoss
8 | from .ibot_patch_loss import iBOTPatchLoss
9 | from .koleo_loss import KoLeoLoss
10 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/loss/dino_clstoken_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.distributed as dist
9 | import torch.nn.functional as F
10 | from torch import nn
11 |
12 |
13 | class DINOLoss(nn.Module):
14 | def __init__(
15 | self,
16 | out_dim,
17 | student_temp=0.1,
18 | center_momentum=0.9,
19 | ):
20 | super().__init__()
21 | self.student_temp = student_temp
22 | self.center_momentum = center_momentum
23 | self.register_buffer("center", torch.zeros(1, out_dim))
24 | self.updated = True
25 | self.reduce_handle = None
26 | self.len_teacher_output = None
27 | self.async_batch_center = None
28 |
29 | @torch.no_grad()
30 | def softmax_center_teacher(self, teacher_output, teacher_temp):
31 | self.apply_center_update()
32 | # teacher centering and sharpening
33 | return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1)
34 |
35 | @torch.no_grad()
36 | def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
37 | teacher_output = teacher_output.float()
38 | world_size = dist.get_world_size() if dist.is_initialized() else 1
39 | Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper
40 | B = Q.shape[1] * world_size # number of samples to assign
41 | K = Q.shape[0] # how many prototypes
42 |
43 | # make the matrix sums to 1
44 | sum_Q = torch.sum(Q)
45 | if dist.is_initialized():
46 | dist.all_reduce(sum_Q)
47 | Q /= sum_Q
48 |
49 | for it in range(n_iterations):
50 | # normalize each row: total weight per prototype must be 1/K
51 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
52 | if dist.is_initialized():
53 | dist.all_reduce(sum_of_rows)
54 | Q /= sum_of_rows
55 | Q /= K
56 |
57 | # normalize each column: total weight per sample must be 1/B
58 | Q /= torch.sum(Q, dim=0, keepdim=True)
59 | Q /= B
60 |
61 | Q *= B # the columns must sum to 1 so that Q is an assignment
62 | return Q.t()
63 |
64 | def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
65 | """
66 | Cross-entropy between softmax outputs of the teacher and student networks.
67 | """
68 | # TODO: Use cross_entropy_distribution here
69 | total_loss = 0
70 | for s in student_output_list:
71 | lsm = F.log_softmax(s / self.student_temp, dim=-1)
72 | for t in teacher_out_softmaxed_centered_list:
73 | loss = torch.sum(t * lsm, dim=-1)
74 | total_loss -= loss.mean()
75 | return total_loss
76 |
77 | @torch.no_grad()
78 | def update_center(self, teacher_output):
79 | self.reduce_center_update(teacher_output)
80 |
81 | @torch.no_grad()
82 | def reduce_center_update(self, teacher_output):
83 | self.updated = False
84 | self.len_teacher_output = len(teacher_output)
85 | self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
86 | if dist.is_initialized():
87 | self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
88 |
89 | @torch.no_grad()
90 | def apply_center_update(self):
91 | if self.updated is False:
92 | world_size = dist.get_world_size() if dist.is_initialized() else 1
93 |
94 | if self.reduce_handle is not None:
95 | self.reduce_handle.wait()
96 | _t = self.async_batch_center / (self.len_teacher_output * world_size)
97 |
98 | self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
99 |
100 | self.updated = True
101 |
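# Usage sketch (illustrative; not part of the upstream file), single process:
# center and sharpen the teacher outputs, compute the cross-entropy term, then
# queue the EMA center update for the next step. Shapes and temperatures are
# arbitrary example values.
if __name__ == "__main__":
    loss_fn = DINOLoss(out_dim=64)
    teacher_logits = torch.randn(4, 64)
    student_logits = torch.randn(4, 64)
    teacher_probs = loss_fn.softmax_center_teacher(teacher_logits, teacher_temp=0.04)
    loss = loss_fn([student_logits], [teacher_probs])
    loss_fn.update_center(teacher_logits)    # asynchronous EMA center update
    print(float(loss))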
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.distributed as dist
9 | import torch.nn.functional as F
10 | from torch import nn
11 |
12 | import logging
13 |
14 |
15 | logger = logging.getLogger("dinov2")
16 |
17 |
18 | try:
19 | from xformers.ops import cross_entropy
20 |
21 | def lossfunc(t, s, temp):
22 | s = s.float()
23 | t = t.float()
24 | if s.ndim == 2:
25 | return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
26 | elif s.ndim == 3:
27 | return -cross_entropy(s, t, temp, bw_inplace=True)
28 |
29 | except ImportError:
30 |
31 | def lossfunc(t, s, temp):
32 | return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)
33 |
34 |
35 | class iBOTPatchLoss(nn.Module):
36 | def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
37 | super().__init__()
38 | self.student_temp = student_temp
39 | self.center_momentum = center_momentum
40 | self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
41 | self.updated = True
42 | self.reduce_handle = None
43 | self.len_teacher_patch_tokens = None
44 | self.async_batch_center = None
45 |
46 | @torch.no_grad()
47 | def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
48 | self.apply_center_update()
49 | # teacher centering and sharpening
50 | #
51 | # WARNING:
52 | # as self.center is a float32, everything gets casted to float32 afterwards
53 | #
54 | # teacher_patch_tokens = teacher_patch_tokens.float()
55 | # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1)
56 |
57 | return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)
58 |
59 | # this is experimental, keep everything in float16 and let's see what happens:
60 | # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1)
61 |
62 | @torch.no_grad()
63 | def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
64 | teacher_output = teacher_output.float()
65 | # world_size = dist.get_world_size() if dist.is_initialized() else 1
66 | Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper
67 | # B = Q.shape[1] * world_size # number of samples to assign
68 | B = n_masked_patches_tensor
69 | dist.all_reduce(B)
70 | K = Q.shape[0] # how many prototypes
71 |
72 | # make the matrix sums to 1
73 | sum_Q = torch.sum(Q)
74 | if dist.is_initialized():
75 | dist.all_reduce(sum_Q)
76 | Q /= sum_Q
77 |
78 | for it in range(n_iterations):
79 | # normalize each row: total weight per prototype must be 1/K
80 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
81 | if dist.is_initialized():
82 | dist.all_reduce(sum_of_rows)
83 | Q /= sum_of_rows
84 | Q /= K
85 |
86 | # normalize each column: total weight per sample must be 1/B
87 | Q /= torch.sum(Q, dim=0, keepdim=True)
88 | Q /= B
89 |
90 | Q *= B # the columns must sum to 1 so that Q is an assignment
91 | return Q.t()
92 |
93 | def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
94 | """
95 | Cross-entropy between softmax outputs of the teacher and student networks.
96 | student_patch_tokens: (B, N, D) tensor
97 | teacher_patch_tokens: (B, N, D) tensor
98 | student_masks_flat: (B, N) tensor
99 | """
100 | t = teacher_patch_tokens
101 | s = student_patch_tokens
102 | loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
103 | loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
104 | return -loss.mean()
105 |
106 | def forward_masked(
107 | self,
108 | student_patch_tokens_masked,
109 | teacher_patch_tokens_masked,
110 | student_masks_flat,
111 | n_masked_patches=None,
112 | masks_weight=None,
113 | ):
114 | t = teacher_patch_tokens_masked
115 | s = student_patch_tokens_masked
116 | # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
117 | loss = lossfunc(t, s, self.student_temp)
118 | if masks_weight is None:
119 | masks_weight = (
120 | (1 / student_masks_flat.sum(-1).clamp(min=1.0))
121 | .unsqueeze(-1)
122 | .expand_as(student_masks_flat)[student_masks_flat]
123 | )
124 | if n_masked_patches is not None:
125 | loss = loss[:n_masked_patches]
126 | loss = loss * masks_weight
127 | return -loss.sum() / student_masks_flat.shape[0]
128 |
129 | @torch.no_grad()
130 | def update_center(self, teacher_patch_tokens):
131 | self.reduce_center_update(teacher_patch_tokens)
132 |
133 | @torch.no_grad()
134 | def reduce_center_update(self, teacher_patch_tokens):
135 | self.updated = False
136 | self.len_teacher_patch_tokens = len(teacher_patch_tokens)
137 | self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
138 | if dist.is_initialized():
139 | self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
140 |
141 | @torch.no_grad()
142 | def apply_center_update(self):
143 | if self.updated is False:
144 | world_size = dist.get_world_size() if dist.is_initialized() else 1
145 |
146 | if self.reduce_handle is not None:
147 | self.reduce_handle.wait()
148 | _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)
149 |
150 | self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
151 |
152 | self.updated = True
153 |
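# Usage sketch (illustrative; not part of the upstream file), single process:
# center/sharpen the teacher patch tokens, then score the student against them,
# counting only the masked patches. Shapes and the mask layout are arbitrary
# example values.
if __name__ == "__main__":
    loss_fn = iBOTPatchLoss(patch_out_dim=64)
    teacher = torch.randn(2, 16, 64)                   # (B, patches, prototypes)
    student = torch.randn(2, 16, 64)
    masks = torch.zeros(2, 16, dtype=torch.bool)
    masks[:, :4] = True                                # pretend 4 patches per image are masked
    teacher_probs = loss_fn.softmax_center_teacher(teacher, teacher_temp=0.04)
    print(float(loss_fn(student, teacher_probs, masks)))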
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/loss/koleo_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | # import torch.distributed as dist
14 |
15 |
16 | logger = logging.getLogger("dinov2")
17 |
18 |
19 | class KoLeoLoss(nn.Module):
20 | """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search"""
21 |
22 | def __init__(self):
23 | super().__init__()
24 | self.pdist = nn.PairwiseDistance(2, eps=1e-8)
25 |
26 | def pairwise_NNs_inner(self, x):
27 | """
28 | Pairwise nearest neighbors for L2-normalized vectors.
29 | Uses Torch rather than Faiss to remain on GPU.
30 | """
31 |         # pairwise dot products (= inverse distance)
32 | dots = torch.mm(x, x.t())
33 | n = x.shape[0]
34 | dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1
35 | # max inner prod -> min distance
36 | _, I = torch.max(dots, dim=1) # noqa: E741
37 | return I
38 |
39 | def forward(self, student_output, eps=1e-8):
40 | """
41 | Args:
42 | student_output (BxD): backbone output of student
43 | """
44 | with torch.cuda.amp.autocast(enabled=False):
45 | student_output = F.normalize(student_output, eps=eps, p=2, dim=-1)
46 | I = self.pairwise_NNs_inner(student_output) # noqa: E741
47 | distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B
48 | loss = -torch.log(distances + eps).mean()
49 | return loss
50 |
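# Usage sketch (illustrative; not part of the upstream file): the regularizer
# rewards spread-out embeddings, so near-duplicate rows give a larger loss than
# random ones. Sizes are arbitrary example values.
if __name__ == "__main__":
    koleo = KoLeoLoss()
    spread = torch.randn(16, 32)
    collapsed = spread[:1].repeat(16, 1) + 1e-3 * torch.randn(16, 32)
    print(float(koleo(spread)), float(koleo(collapsed)))   # second value is larger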
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 |
9 | from . import vision_transformer as vits
10 |
11 |
12 | logger = logging.getLogger("dinov2")
13 |
14 |
15 | def build_model(args, only_teacher=False, img_size=224):
16 | args.arch = args.arch.removesuffix("_memeff")
17 | if "vit" in args.arch:
18 | vit_kwargs = dict(
19 | img_size=img_size,
20 | patch_size=args.patch_size,
21 | init_values=args.layerscale,
22 | ffn_layer=args.ffn_layer,
23 | block_chunks=args.block_chunks,
24 | qkv_bias=args.qkv_bias,
25 | proj_bias=args.proj_bias,
26 | ffn_bias=args.ffn_bias,
27 | )
28 | teacher = vits.__dict__[args.arch](**vit_kwargs)
29 | if only_teacher:
30 | return teacher, teacher.embed_dim
31 | student = vits.__dict__[args.arch](
32 | **vit_kwargs,
33 | drop_path_rate=args.drop_path_rate,
34 | drop_path_uniform=args.drop_path_uniform,
35 | )
36 | embed_dim = student.embed_dim
37 | return student, teacher, embed_dim
38 |
39 |
40 | def build_model_from_cfg(cfg, only_teacher=False):
41 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
42 |
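# Usage sketch (illustrative; not part of the upstream file): build_model reads
# the architecture name and ViT hyper-parameters off a config-style namespace.
# The field values below are assumptions for illustration only, not shipped
# defaults, and "vit_small" is assumed to be one of the constructors defined in
# the sibling vision_transformer module.
#
#     from types import SimpleNamespace
#     from dinov2.models import build_model
#
#     args = SimpleNamespace(
#         arch="vit_small", patch_size=14, layerscale=1e-5, ffn_layer="mlp",
#         block_chunks=0, qkv_bias=True, proj_bias=True, ffn_bias=True,
#     )
#     teacher, embed_dim = build_model(args, only_teacher=True, img_size=224)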
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/run/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/knn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 | from dinov2.eval.knn import get_args_parser as get_knn_args_parser
12 | from dinov2.logging import setup_logging
13 | from dinov2.run.submit import get_args_parser, submit_jobs
14 |
15 |
16 | logger = logging.getLogger("dinov2")
17 |
18 |
19 | class Evaluator:
20 | def __init__(self, args):
21 | self.args = args
22 |
23 | def __call__(self):
24 | from dinov2.eval.knn import main as knn_main
25 |
26 | self._setup_args()
27 | knn_main(self.args)
28 |
29 | def checkpoint(self):
30 | import submitit
31 |
32 | logger.info(f"Requeuing {self.args}")
33 | empty = type(self)(self.args)
34 | return submitit.helpers.DelayedSubmission(empty)
35 |
36 | def _setup_args(self):
37 | import submitit
38 |
39 | job_env = submitit.JobEnvironment()
40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
42 | logger.info(f"Args: {self.args}")
43 |
44 |
45 | def main():
46 | description = "Submitit launcher for DINOv2 k-NN evaluation"
47 | knn_args_parser = get_knn_args_parser(add_help=False)
48 | parents = [knn_args_parser]
49 | args_parser = get_args_parser(description=description, parents=parents)
50 | args = args_parser.parse_args()
51 |
52 | setup_logging()
53 |
54 | assert os.path.exists(args.config_file), "Configuration file does not exist!"
55 | submit_jobs(Evaluator, args, name="dinov2:knn")
56 | return 0
57 |
58 |
59 | if __name__ == "__main__":
60 | sys.exit(main())
61 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 | from dinov2.eval.linear import get_args_parser as get_linear_args_parser
12 | from dinov2.logging import setup_logging
13 | from dinov2.run.submit import get_args_parser, submit_jobs
14 |
15 |
16 | logger = logging.getLogger("dinov2")
17 |
18 |
19 | class Evaluator:
20 | def __init__(self, args):
21 | self.args = args
22 |
23 | def __call__(self):
24 | from dinov2.eval.linear import main as linear_main
25 |
26 | self._setup_args()
27 | linear_main(self.args)
28 |
29 | def checkpoint(self):
30 | import submitit
31 |
32 | logger.info(f"Requeuing {self.args}")
33 | empty = type(self)(self.args)
34 | return submitit.helpers.DelayedSubmission(empty)
35 |
36 | def _setup_args(self):
37 | import submitit
38 |
39 | job_env = submitit.JobEnvironment()
40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
42 | logger.info(f"Args: {self.args}")
43 |
44 |
45 | def main():
46 | description = "Submitit launcher for DINOv2 linear evaluation"
47 | linear_args_parser = get_linear_args_parser(add_help=False)
48 | parents = [linear_args_parser]
49 | args_parser = get_args_parser(description=description, parents=parents)
50 | args = args_parser.parse_args()
51 |
52 | setup_logging()
53 |
54 | assert os.path.exists(args.config_file), "Configuration file does not exist!"
55 | submit_jobs(Evaluator, args, name="dinov2:linear")
56 | return 0
57 |
58 |
59 | if __name__ == "__main__":
60 | sys.exit(main())
61 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/log_regression.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 | from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser
12 | from dinov2.logging import setup_logging
13 | from dinov2.run.submit import get_args_parser, submit_jobs
14 |
15 |
16 | logger = logging.getLogger("dinov2")
17 |
18 |
19 | class Evaluator:
20 | def __init__(self, args):
21 | self.args = args
22 |
23 | def __call__(self):
24 | from dinov2.eval.log_regression import main as log_regression_main
25 |
26 | self._setup_args()
27 | log_regression_main(self.args)
28 |
29 | def checkpoint(self):
30 | import submitit
31 |
32 | logger.info(f"Requeuing {self.args}")
33 | empty = type(self)(self.args)
34 | return submitit.helpers.DelayedSubmission(empty)
35 |
36 | def _setup_args(self):
37 | import submitit
38 |
39 | job_env = submitit.JobEnvironment()
40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
42 | logger.info(f"Args: {self.args}")
43 |
44 |
45 | def main():
46 | description = "Submitit launcher for DINOv2 logistic evaluation"
47 | log_regression_args_parser = get_log_regression_args_parser(add_help=False)
48 | parents = [log_regression_args_parser]
49 | args_parser = get_args_parser(description=description, parents=parents)
50 | args = args_parser.parse_args()
51 |
52 | setup_logging()
53 |
54 | assert os.path.exists(args.config_file), "Configuration file does not exist!"
55 | submit_jobs(Evaluator, args, name="dinov2:logreg")
56 | return 0
57 |
58 |
59 | if __name__ == "__main__":
60 | sys.exit(main())
61 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/run/submit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import logging
9 | import os
10 | from pathlib import Path
11 | from typing import List, Optional
12 |
13 | import submitit
14 |
15 | from dinov2.utils.cluster import (
16 | get_slurm_executor_parameters,
17 | get_slurm_partition,
18 | get_user_checkpoint_path,
19 | )
20 |
21 |
22 | logger = logging.getLogger("dinov2")
23 |
24 |
25 | def get_args_parser(
26 | description: Optional[str] = None,
27 | parents: Optional[List[argparse.ArgumentParser]] = None,
28 | add_help: bool = True,
29 | ) -> argparse.ArgumentParser:
30 | parents = parents or []
31 | slurm_partition = get_slurm_partition()
32 | parser = argparse.ArgumentParser(
33 | description=description,
34 | parents=parents,
35 | add_help=add_help,
36 | )
37 | parser.add_argument(
38 | "--ngpus",
39 | "--gpus",
40 | "--gpus-per-node",
41 | default=8,
42 | type=int,
43 | help="Number of GPUs to request on each node",
44 | )
45 | parser.add_argument(
46 | "--nodes",
47 | "--nnodes",
48 | default=2,
49 | type=int,
50 | help="Number of nodes to request",
51 | )
52 | parser.add_argument(
53 | "--timeout",
54 | default=2800,
55 | type=int,
56 |         help="Duration of the job, in minutes",
57 | )
58 | parser.add_argument(
59 | "--partition",
60 | default=slurm_partition,
61 | type=str,
62 | help="Partition where to submit",
63 | )
64 | parser.add_argument(
65 | "--use-volta32",
66 | action="store_true",
67 | help="Request V100-32GB GPUs",
68 | )
69 | parser.add_argument(
70 | "--comment",
71 | default="",
72 | type=str,
73 | help="Comment to pass to scheduler, e.g. priority message",
74 | )
75 | parser.add_argument(
76 | "--exclude",
77 | default="",
78 | type=str,
79 | help="Nodes to exclude",
80 | )
81 | return parser
82 |
83 |
84 | def get_shared_folder() -> Path:
85 | user_checkpoint_path = get_user_checkpoint_path()
86 | if user_checkpoint_path is None:
87 | raise RuntimeError("Path to user checkpoint cannot be determined")
88 | path = user_checkpoint_path / "experiments"
89 | path.mkdir(exist_ok=True)
90 | return path
91 |
92 |
93 | def submit_jobs(task_class, args, name: str):
94 | if not args.output_dir:
95 | args.output_dir = str(get_shared_folder() / "%j")
96 |
97 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
98 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)
99 |
100 | kwargs = {}
101 | if args.use_volta32:
102 | kwargs["slurm_constraint"] = "volta32gb"
103 | if args.comment:
104 | kwargs["slurm_comment"] = args.comment
105 | if args.exclude:
106 | kwargs["slurm_exclude"] = args.exclude
107 |
108 | executor_params = get_slurm_executor_parameters(
109 | nodes=args.nodes,
110 | num_gpus_per_node=args.ngpus,
111 | timeout_min=args.timeout, # max is 60 * 72
112 | slurm_signal_delay_s=120,
113 | slurm_partition=args.partition,
114 | **kwargs,
115 | )
116 | executor.update_parameters(name=name, **executor_params)
117 |
118 | task = task_class(args)
119 | job = executor.submit(task)
120 |
121 | logger.info(f"Submitted job_id: {job.job_id}")
122 | str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id))
123 | logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}")
124 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/run/train/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 | from dinov2.logging import setup_logging
12 | from dinov2.train import get_args_parser as get_train_args_parser
13 | from dinov2.run.submit import get_args_parser, submit_jobs
14 |
15 |
16 | logger = logging.getLogger("dinov2")
17 |
18 |
19 | class Trainer(object):
20 | def __init__(self, args):
21 | self.args = args
22 |
23 | def __call__(self):
24 | from dinov2.train import main as train_main
25 |
26 | self._setup_args()
27 | train_main(self.args)
28 |
29 | def checkpoint(self):
30 | import submitit
31 |
32 | logger.info(f"Requeuing {self.args}")
33 | empty = type(self)(self.args)
34 | return submitit.helpers.DelayedSubmission(empty)
35 |
36 | def _setup_args(self):
37 | import submitit
38 |
39 | job_env = submitit.JobEnvironment()
40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id))
41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
42 | logger.info(f"Args: {self.args}")
43 |
44 |
45 | def main():
46 | description = "Submitit launcher for DINOv2 training"
47 | train_args_parser = get_train_args_parser(add_help=False)
48 | parents = [train_args_parser]
49 | args_parser = get_args_parser(description=description, parents=parents)
50 | args = args_parser.parse_args()
51 |
52 | setup_logging()
53 |
54 | assert os.path.exists(args.config_file), "Configuration file does not exist!"
55 | submit_jobs(Trainer, args, name="dinov2:train")
56 | return 0
57 |
58 |
59 | if __name__ == "__main__":
60 | sys.exit(main())
61 |
--------------------------------------------------------------------------------
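The launcher above reuses every training flag by passing the training parser (built with add_help=False) as a parent of the Slurm submission parser. A self-contained sketch of the same argparse composition pattern, with hypothetical flags standing in for the real ones:

import argparse

# Child parser: defines the "training" options and deliberately omits its own -h
train_parser = argparse.ArgumentParser(add_help=False)
train_parser.add_argument("--config-file", default="", help="path to config file")

# Launcher parser: adds scheduler options and inherits the training flags via parents=
launcher_parser = argparse.ArgumentParser(description="launcher", parents=[train_parser])
launcher_parser.add_argument("--nodes", type=int, default=2)

args = launcher_parser.parse_args(["--config-file", "cfg.yaml", "--nodes", "4"])
print(args.config_file, args.nodes)  # cfg.yaml 4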
/torchhub/facebookresearch_dinov2_main/dinov2/train/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .train import get_args_parser, main
8 | from .ssl_meta_arch import SSLMetaArch
9 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/train/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import logging
9 | import math
10 | import os
11 | from functools import partial
12 |
13 | from fvcore.common.checkpoint import PeriodicCheckpointer
14 | import torch
15 |
16 | from dinov2.data import SamplerType, make_data_loader, make_dataset
17 | from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
18 | import dinov2.distributed as distributed
19 | from dinov2.fsdp import FSDPCheckpointer
20 | from dinov2.logging import MetricLogger
21 | from dinov2.utils.config import setup
22 | from dinov2.utils.utils import CosineScheduler
23 |
24 | from dinov2.train.ssl_meta_arch import SSLMetaArch
25 |
26 |
27 | torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default
28 | logger = logging.getLogger("dinov2")
29 |
30 |
31 | def get_args_parser(add_help: bool = True):
32 | parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
33 | parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
34 | parser.add_argument(
35 | "--no-resume",
36 | action="store_true",
37 | help="Whether to not attempt to resume from the checkpoint directory. ",
38 | )
39 | parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
40 | parser.add_argument("--eval", type=str, default="", help="Eval type to perform")
41 | parser.add_argument(
42 | "opts",
43 | help="""
44 | Modify config options at the end of the command. For Yacs configs, use
45 | space-separated "PATH.KEY VALUE" pairs.
46 | For python-based LazyConfig, use "path.key=value".
47 | """.strip(),
48 | default=None,
49 | nargs=argparse.REMAINDER,
50 | )
51 | parser.add_argument(
52 | "--output-dir",
53 | "--output_dir",
54 | default="",
55 | type=str,
56 | help="Output directory to save logs and checkpoints",
57 | )
58 |
59 | return parser
60 |
61 |
62 | def build_optimizer(cfg, params_groups):
63 | return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2))
64 |
65 |
66 | def build_schedulers(cfg):
67 | OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
68 | lr = dict(
69 | base_value=cfg.optim["lr"],
70 | final_value=cfg.optim["min_lr"],
71 | total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
72 | warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
73 | start_warmup_value=0,
74 | )
75 | wd = dict(
76 | base_value=cfg.optim["weight_decay"],
77 | final_value=cfg.optim["weight_decay_end"],
78 | total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
79 | )
80 | momentum = dict(
81 | base_value=cfg.teacher["momentum_teacher"],
82 | final_value=cfg.teacher["final_momentum_teacher"],
83 | total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
84 | )
85 | teacher_temp = dict(
86 | base_value=cfg.teacher["teacher_temp"],
87 | final_value=cfg.teacher["teacher_temp"],
88 | total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
89 | warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
90 | start_warmup_value=cfg.teacher["warmup_teacher_temp"],
91 | )
92 |
93 | lr_schedule = CosineScheduler(**lr)
94 | wd_schedule = CosineScheduler(**wd)
95 | momentum_schedule = CosineScheduler(**momentum)
96 | teacher_temp_schedule = CosineScheduler(**teacher_temp)
97 | last_layer_lr_schedule = CosineScheduler(**lr)
98 |
99 | last_layer_lr_schedule.schedule[
100 | : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
101 | ] = 0 # mimicking the original schedules
102 |
103 | logger.info("Schedulers ready.")
104 |
105 | return (
106 | lr_schedule,
107 | wd_schedule,
108 | momentum_schedule,
109 | teacher_temp_schedule,
110 | last_layer_lr_schedule,
111 | )
112 |
113 |
114 | def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
115 | for param_group in optimizer.param_groups:
116 | is_last_layer = param_group["is_last_layer"]
117 | lr_multiplier = param_group["lr_multiplier"]
118 | wd_multiplier = param_group["wd_multiplier"]
119 | param_group["weight_decay"] = wd * wd_multiplier
120 | param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier
121 |
122 |
123 | def do_test(cfg, model, iteration):
124 | new_state_dict = model.teacher.state_dict()
125 |
126 | if distributed.is_main_process():
127 | iterstring = str(iteration)
128 | eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring)
129 | os.makedirs(eval_dir, exist_ok=True)
130 | # save teacher checkpoint
131 | teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
132 | torch.save({"teacher": new_state_dict}, teacher_ckp_path)
133 |
134 |
135 | def do_train(cfg, model, resume=False):
136 | model.train()
137 | inputs_dtype = torch.half
138 | fp16_scaler = model.fp16_scaler # for mixed precision training
139 |
140 | # setup optimizer
141 |
142 | optimizer = build_optimizer(cfg, model.get_params_groups())
143 | (
144 | lr_schedule,
145 | wd_schedule,
146 | momentum_schedule,
147 | teacher_temp_schedule,
148 | last_layer_lr_schedule,
149 | ) = build_schedulers(cfg)
150 |
151 | # checkpointer
152 | checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)
153 |
154 | start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
155 |
156 | OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
157 | max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH
158 |
159 | periodic_checkpointer = PeriodicCheckpointer(
160 | checkpointer,
161 | period=3 * OFFICIAL_EPOCH_LENGTH,
162 | max_iter=max_iter,
163 | max_to_keep=3,
164 | )
165 |
166 | # setup data preprocessing
167 |
168 | img_size = cfg.crops.global_crops_size
169 | patch_size = cfg.student.patch_size
170 | n_tokens = (img_size // patch_size) ** 2
171 | mask_generator = MaskingGenerator(
172 | input_size=(img_size // patch_size, img_size // patch_size),
173 | max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
174 | )
175 |
176 | data_transform = DataAugmentationDINO(
177 | cfg.crops.global_crops_scale,
178 | cfg.crops.local_crops_scale,
179 | cfg.crops.local_crops_number,
180 | global_crops_size=cfg.crops.global_crops_size,
181 | local_crops_size=cfg.crops.local_crops_size,
182 | )
183 |
184 | collate_fn = partial(
185 | collate_data_and_cast,
186 | mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
187 | mask_probability=cfg.ibot.mask_sample_probability,
188 | n_tokens=n_tokens,
189 | mask_generator=mask_generator,
190 | dtype=inputs_dtype,
191 | )
192 |
193 | # setup data loader
194 |
195 | dataset = make_dataset(
196 | dataset_str=cfg.train.dataset_path,
197 | transform=data_transform,
198 | target_transform=lambda _: (),
199 | )
200 | # sampler_type = SamplerType.INFINITE
201 | sampler_type = SamplerType.SHARDED_INFINITE
202 | data_loader = make_data_loader(
203 | dataset=dataset,
204 | batch_size=cfg.train.batch_size_per_gpu,
205 | num_workers=cfg.train.num_workers,
206 | shuffle=True,
207 | seed=start_iter, # TODO: Fix this -- cfg.train.seed
208 | sampler_type=sampler_type,
209 | sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
210 | drop_last=True,
211 | collate_fn=collate_fn,
212 | )
213 |
214 | # training loop
215 |
216 | iteration = start_iter
217 |
218 | logger.info("Starting training from iteration {}".format(start_iter))
219 | metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
220 | metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
221 | header = "Training"
222 |
223 | for data in metric_logger.log_every(
224 | data_loader,
225 | 10,
226 | header,
227 | max_iter,
228 | start_iter,
229 | ):
230 | current_batch_size = data["collated_global_crops"].shape[0] / 2
231 | if iteration > max_iter:
232 | return
233 |
234 | # apply schedules
235 |
236 | lr = lr_schedule[iteration]
237 | wd = wd_schedule[iteration]
238 | mom = momentum_schedule[iteration]
239 | teacher_temp = teacher_temp_schedule[iteration]
240 | last_layer_lr = last_layer_lr_schedule[iteration]
241 | apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)
242 |
243 | # compute losses
244 |
245 | optimizer.zero_grad(set_to_none=True)
246 | loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)
247 |
248 | # clip gradients
249 |
250 | if fp16_scaler is not None:
251 | if cfg.optim.clip_grad:
252 | fp16_scaler.unscale_(optimizer)
253 | for v in model.student.values():
254 | v.clip_grad_norm_(cfg.optim.clip_grad)
255 | fp16_scaler.step(optimizer)
256 | fp16_scaler.update()
257 | else:
258 | if cfg.optim.clip_grad:
259 | for v in model.student.values():
260 | v.clip_grad_norm_(cfg.optim.clip_grad)
261 | optimizer.step()
262 |
263 | # perform teacher EMA update
264 |
265 | model.update_teacher(mom)
266 |
267 | # logging
268 |
269 | if distributed.get_global_size() > 1:
270 | for v in loss_dict.values():
271 | torch.distributed.all_reduce(v)
272 | loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}
273 |
274 | if math.isnan(sum(loss_dict_reduced.values())):
275 | logger.info("NaN detected")
276 | raise AssertionError
277 | losses_reduced = sum(loss for loss in loss_dict_reduced.values())
278 |
279 | metric_logger.update(lr=lr)
280 | metric_logger.update(wd=wd)
281 | metric_logger.update(mom=mom)
282 | metric_logger.update(last_layer_lr=last_layer_lr)
283 | metric_logger.update(current_batch_size=current_batch_size)
284 | metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
285 |
286 | # checkpointing and testing
287 |
288 | if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
289 | do_test(cfg, model, f"training_{iteration}")
290 | torch.cuda.synchronize()
291 | periodic_checkpointer.step(iteration)
292 |
293 | iteration = iteration + 1
294 | metric_logger.synchronize_between_processes()
295 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
296 |
297 |
298 | def main(args):
299 | cfg = setup(args)
300 |
301 | model = SSLMetaArch(cfg).to(torch.device("cuda"))
302 | model.prepare_for_distributed_training()
303 |
304 | logger.info("Model:\n{}".format(model))
305 | if args.eval_only:
306 | iteration = (
307 | FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
308 | .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
309 | .get("iteration", -1)
310 | + 1
311 | )
312 | return do_test(cfg, model, f"manual_{iteration}")
313 |
314 | do_train(cfg, model, resume=not args.no_resume)
315 |
316 |
317 | if __name__ == "__main__":
318 | args = get_args_parser(add_help=True).parse_args()
319 | main(args)
320 |
--------------------------------------------------------------------------------
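Each training iteration looks up the scheduled lr, wd and last-layer lr, then apply_optim_scheduler rewrites every param group by scaling those values with the per-group multipliers produced by get_params_groups_with_decay. A standalone sketch of that rewrite (a local mirror of the function above rather than an import, and assuming the groups carry the extra keys the training code expects):

import torch

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
groups = [
    {"params": [params[0]], "is_last_layer": False, "lr_multiplier": 0.5, "wd_multiplier": 1.0},
    {"params": [params[1]], "is_last_layer": True, "lr_multiplier": 1.0, "wd_multiplier": 0.0},
]
optimizer = torch.optim.AdamW(groups, lr=0.0)  # lr is overwritten on every iteration anyway

def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
    # Same logic as the training script: scale the scheduled values by per-group multipliers
    for group in optimizer.param_groups:
        group["weight_decay"] = wd * group["wd_multiplier"]
        group["lr"] = (last_layer_lr if group["is_last_layer"] else lr) * group["lr_multiplier"]

apply_optim_scheduler(optimizer, lr=1e-3, wd=0.04, last_layer_lr=0.0)
print([g["lr"] for g in optimizer.param_groups])  # [0.0005, 0.0] -- last layer held at 0 while its schedule is frozen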
/torchhub/facebookresearch_dinov2_main/dinov2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/dinov2/utils/cluster.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from enum import Enum
8 | import os
9 | from pathlib import Path
10 | from typing import Any, Dict, Optional
11 |
12 |
13 | class ClusterType(Enum):
14 | AWS = "aws"
15 | FAIR = "fair"
16 | RSC = "rsc"
17 |
18 |
19 | def _guess_cluster_type() -> ClusterType:
20 | uname = os.uname()
21 | if uname.sysname == "Linux":
22 | if uname.release.endswith("-aws"):
23 | # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
24 | return ClusterType.AWS
25 | elif uname.nodename.startswith("rsc"):
26 | # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
27 | return ClusterType.RSC
28 |
29 | return ClusterType.FAIR
30 |
31 |
32 | def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
33 | if cluster_type is None:
34 | return _guess_cluster_type()
35 |
36 | return cluster_type
37 |
38 |
39 | def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
40 | cluster_type = get_cluster_type(cluster_type)
41 | if cluster_type is None:
42 | return None
43 |
44 | CHECKPOINT_DIRNAMES = {
45 | ClusterType.AWS: "checkpoints",
46 | ClusterType.FAIR: "checkpoint",
47 | ClusterType.RSC: "checkpoint/dino",
48 | }
49 | return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
50 |
51 |
52 | def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
53 | checkpoint_path = get_checkpoint_path(cluster_type)
54 | if checkpoint_path is None:
55 | return None
56 |
57 | username = os.environ.get("USER")
58 | assert username is not None
59 | return checkpoint_path / username
60 |
61 |
62 | def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
63 | cluster_type = get_cluster_type(cluster_type)
64 | if cluster_type is None:
65 | return None
66 |
67 | SLURM_PARTITIONS = {
68 | ClusterType.AWS: "learnlab",
69 | ClusterType.FAIR: "learnlab",
70 | ClusterType.RSC: "learn",
71 | }
72 | return SLURM_PARTITIONS[cluster_type]
73 |
74 |
75 | def get_slurm_executor_parameters(
76 | nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
77 | ) -> Dict[str, Any]:
78 | # create default parameters
79 | params = {
80 | "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
81 | "gpus_per_node": num_gpus_per_node,
82 | "tasks_per_node": num_gpus_per_node, # one task per GPU
83 | "cpus_per_task": 10,
84 | "nodes": nodes,
85 | "slurm_partition": get_slurm_partition(cluster_type),
86 | }
87 | # apply cluster-specific adjustments
88 | cluster_type = get_cluster_type(cluster_type)
89 | if cluster_type == ClusterType.AWS:
90 | params["cpus_per_task"] = 12
91 | del params["mem_gb"]
92 | elif cluster_type == ClusterType.RSC:
93 | params["cpus_per_task"] = 12
94 | # set additional parameters / apply overrides
95 | params.update(kwargs)
96 | return params
97 |
--------------------------------------------------------------------------------
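For a fixed cluster type the executor parameters are fully deterministic. For example, a hypothetical 2-node, 8-GPU-per-node AWS job (assuming the dinov2 package is importable):

from dinov2.utils.cluster import ClusterType, get_slurm_executor_parameters

params = get_slurm_executor_parameters(nodes=2, num_gpus_per_node=8, cluster_type=ClusterType.AWS)
# AWS branch: cpus_per_task is bumped to 12 and the mem_gb=0 default is removed
# {'gpus_per_node': 8, 'tasks_per_node': 8, 'cpus_per_task': 12, 'nodes': 2, 'slurm_partition': 'learnlab'}
print(params)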
/torchhub/facebookresearch_dinov2_main/dinov2/utils/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | import logging
9 | import os
10 |
11 | from omegaconf import OmegaConf
12 |
13 | import dinov2.distributed as distributed
14 | from dinov2.logging import setup_logging
15 | from dinov2.utils import utils
16 | from dinov2.configs import dinov2_default_config
17 |
18 |
19 | logger = logging.getLogger("dinov2")
20 |
21 |
22 | def apply_scaling_rules_to_cfg(cfg): # to fix
23 | if cfg.optim.scaling_rule == "sqrt_wrt_1024":
24 | base_lr = cfg.optim.base_lr
25 | cfg.optim.lr = base_lr
26 | cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
27 | logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
28 | else:
29 | raise NotImplementedError
30 | return cfg
31 |
32 |
33 | def write_config(cfg, output_dir, name="config.yaml"):
34 | logger.info(OmegaConf.to_yaml(cfg))
35 | saved_cfg_path = os.path.join(output_dir, name)
36 | with open(saved_cfg_path, "w") as f:
37 | OmegaConf.save(config=cfg, f=f)
38 | return saved_cfg_path
39 |
40 |
41 | def get_cfg_from_args(args):
42 | args.output_dir = os.path.abspath(args.output_dir)
43 | args.opts += [f"train.output_dir={args.output_dir}"]
44 | default_cfg = OmegaConf.create(dinov2_default_config)
45 | cfg = OmegaConf.load(args.config_file)
46 | cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
47 | return cfg
48 |
49 |
50 | def default_setup(args):
51 | distributed.enable(overwrite=True)
52 | seed = getattr(args, "seed", 0)
53 | rank = distributed.get_global_rank()
54 |
55 | global logger
56 | setup_logging(output=args.output_dir, level=logging.INFO)
57 | logger = logging.getLogger("dinov2")
58 |
59 | utils.fix_random_seeds(seed + rank)
60 | logger.info("git:\n {}\n".format(utils.get_sha()))
61 | logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
62 |
63 |
64 | def setup(args):
65 | """
66 | Create configs and perform basic setups.
67 | """
68 | cfg = get_cfg_from_args(args)
69 | os.makedirs(args.output_dir, exist_ok=True)
70 | default_setup(args)
71 | apply_scaling_rules_to_cfg(cfg)
72 | write_config(cfg, args.output_dir)
73 | return cfg
74 |
--------------------------------------------------------------------------------
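The sqrt_wrt_1024 rule in apply_scaling_rules_to_cfg scales the base learning rate by the square root of the global batch size relative to 1024. A small numeric illustration with hypothetical values for the config fields and world size:

import math

base_lr = 0.004          # hypothetical cfg.optim.base_lr
batch_size_per_gpu = 32  # hypothetical cfg.train.batch_size_per_gpu
world_size = 16          # hypothetical distributed.get_global_size()

# sqrt_wrt_1024: lr = base_lr * sqrt(global_batch_size / 1024)
lr = base_lr * math.sqrt(batch_size_per_gpu * world_size / 1024.0)
print(round(lr, 6))  # 0.002828 -- half the reference batch of 1024 gives base_lr / sqrt(2)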
/torchhub/facebookresearch_dinov2_main/dinov2/utils/dtype.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from typing import Dict, Union
9 |
10 | import numpy as np
11 | import torch
12 |
13 |
14 | TypeSpec = Union[str, np.dtype, torch.dtype]
15 |
16 |
17 | _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
18 | np.dtype("bool"): torch.bool,
19 | np.dtype("uint8"): torch.uint8,
20 | np.dtype("int8"): torch.int8,
21 | np.dtype("int16"): torch.int16,
22 | np.dtype("int32"): torch.int32,
23 | np.dtype("int64"): torch.int64,
24 | np.dtype("float16"): torch.float16,
25 | np.dtype("float32"): torch.float32,
26 | np.dtype("float64"): torch.float64,
27 | np.dtype("complex64"): torch.complex64,
28 | np.dtype("complex128"): torch.complex128,
29 | }
30 |
31 |
32 | def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
33 | if isinstance(dtype, torch.dtype):
34 | return dtype
35 | if isinstance(dtype, str):
36 | dtype = np.dtype(dtype)
37 | assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
38 | return _NUMPY_TO_TORCH_DTYPE[dtype]
39 |
--------------------------------------------------------------------------------
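as_torch_dtype normalizes strings and NumPy dtypes to torch dtypes via the lookup table above, while torch dtypes pass through unchanged. A quick usage sketch (assuming the dinov2 package is importable):

import numpy as np
import torch
from dinov2.utils.dtype import as_torch_dtype

print(as_torch_dtype("float32"))          # torch.float32 (strings go through np.dtype first)
print(as_torch_dtype(np.dtype("int64")))  # torch.int64
print(as_torch_dtype(torch.float16))      # torch.float16 (already a torch dtype, returned as-is)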
/torchhub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from collections import defaultdict
8 | import logging
9 |
10 |
11 | logger = logging.getLogger("dinov2")
12 |
13 |
14 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
15 | """
16 | Calculate lr decay rate for different ViT blocks.
17 | Args:
18 | name (string): parameter name.
19 | lr_decay_rate (float): base lr decay rate.
20 | num_layers (int): number of ViT blocks.
21 | Returns:
22 | lr decay rate for the given parameter.
23 | """
24 | layer_id = num_layers + 1
25 | if name.startswith("backbone") or force_is_backbone:
26 | if ".pos_embed" in name or ".patch_embed" in name or ".mask_token" in name or ".cls_token" in name:
27 | layer_id = 0
28 | elif force_is_backbone and (
29 | "pos_embed" in name or "patch_embed" in name or "mask_token" in name or "cls_token" in name
30 | ):
31 | layer_id = 0
32 | elif ".blocks." in name and ".residual." not in name:
33 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
34 | elif chunked_blocks and "blocks." in name and "residual." not in name:
35 | layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
36 | elif "blocks." in name and "residual." not in name:
37 | layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
38 |
39 | return lr_decay_rate ** (num_layers + 1 - layer_id)
40 |
41 |
42 | def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
43 | chunked_blocks = False
44 | if hasattr(model, "n_blocks"):
45 | logger.info("chunked fsdp")
46 | n_blocks = model.n_blocks
47 | chunked_blocks = model.chunked_blocks
48 | elif hasattr(model, "blocks"):
49 | logger.info("first code branch")
50 | n_blocks = len(model.blocks)
51 | elif hasattr(model, "backbone"):
52 | logger.info("second code branch")
53 | n_blocks = len(model.backbone.blocks)
54 | else:
55 | logger.info("else code branch")
56 | n_blocks = 0
57 | all_param_groups = []
58 |
59 | for name, param in model.named_parameters():
60 | name = name.replace("_fsdp_wrapped_module.", "")
61 | if not param.requires_grad:
62 | continue
63 | decay_rate = get_vit_lr_decay_rate(
64 | name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
65 | )
66 | d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
67 |
68 | if "last_layer" in name:
69 | d.update({"is_last_layer": True})
70 |
71 | if name.endswith(".bias") or "norm" in name or "gamma" in name:
72 | d.update({"wd_multiplier": 0.0})
73 |
74 | if "patch_embed" in name:
75 | d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
76 |
77 | all_param_groups.append(d)
78 | logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
79 |
80 | return all_param_groups
81 |
82 |
83 | def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
84 | fused_params_groups = defaultdict(lambda: {"params": []})
85 | for d in all_params_groups:
86 | identifier = ""
87 | for k in keys:
88 | identifier += k + str(d[k]) + "_"
89 |
90 | for k in keys:
91 | fused_params_groups[identifier][k] = d[k]
92 | fused_params_groups[identifier]["params"].append(d["params"])
93 |
94 | return fused_params_groups.values()
95 |
--------------------------------------------------------------------------------
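The layer-wise multiplier is lr_decay_rate ** (num_layers + 1 - layer_id), so parameters in deeper blocks keep more of the base learning rate while the patch embedding (layer 0) gets the strongest decay. A small check with hypothetical parameter names and a decay rate of 0.9 (assuming dinov2 is importable):

from dinov2.utils.param_groups import get_vit_lr_decay_rate

print(get_vit_lr_decay_rate("backbone.patch_embed.proj.weight", 0.9, num_layers=12))  # 0.9 ** 13 ~= 0.254
print(get_vit_lr_decay_rate("backbone.blocks.3.mlp.fc1.weight", 0.9, num_layers=12))  # 0.9 ** 9  ~= 0.387
print(get_vit_lr_decay_rate("head.weight", 0.9, num_layers=12))                       # 0.9 ** 0   = 1.0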
/torchhub/facebookresearch_dinov2_main/dinov2/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import random
10 | import subprocess
11 | from urllib.parse import urlparse
12 |
13 | import numpy as np
14 | import torch
15 | from torch import nn
16 |
17 |
18 | logger = logging.getLogger("dinov2")
19 |
20 |
21 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
22 | if urlparse(pretrained_weights).scheme: # If it looks like a URL
23 | state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
24 | else:
25 | state_dict = torch.load(pretrained_weights, map_location="cpu")
26 | if checkpoint_key is not None and checkpoint_key in state_dict:
27 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
28 | state_dict = state_dict[checkpoint_key]
29 | # remove `module.` prefix
30 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
31 | # remove `backbone.` prefix induced by multicrop wrapper
32 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
33 | msg = model.load_state_dict(state_dict, strict=False)
34 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
35 |
36 |
37 | def fix_random_seeds(seed=31):
38 | """
39 | Fix random seeds.
40 | """
41 | torch.manual_seed(seed)
42 | torch.cuda.manual_seed_all(seed)
43 | np.random.seed(seed)
44 | random.seed(seed)
45 |
46 |
47 | def get_sha():
48 | cwd = os.path.dirname(os.path.abspath(__file__))
49 |
50 | def _run(command):
51 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
52 |
53 | sha = "N/A"
54 | diff = "clean"
55 | branch = "N/A"
56 | try:
57 | sha = _run(["git", "rev-parse", "HEAD"])
58 | subprocess.check_output(["git", "diff"], cwd=cwd)
59 | diff = _run(["git", "diff-index", "HEAD"])
60 | diff = "has uncommitted changes" if diff else "clean"
61 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
62 | except Exception:
63 | pass
64 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
65 | return message
66 |
67 |
68 | class CosineScheduler(object):
69 | def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
70 | super().__init__()
71 | self.final_value = final_value
72 | self.total_iters = total_iters
73 |
74 | freeze_schedule = np.zeros((freeze_iters))
75 |
76 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
77 |
78 | iters = np.arange(total_iters - warmup_iters - freeze_iters)
79 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
80 | self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
81 |
82 | assert len(self.schedule) == self.total_iters
83 |
84 | def __getitem__(self, it):
85 | if it >= self.total_iters:
86 | return self.final_value
87 | else:
88 | return self.schedule[it]
89 |
90 |
91 | def has_batchnorms(model):
92 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
93 | for name, module in model.named_modules():
94 | if isinstance(module, bn_types):
95 | return True
96 | return False
97 |
--------------------------------------------------------------------------------
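CosineScheduler precomputes the full schedule as a NumPy array: an optional frozen prefix of zeros, a linear warmup, then a cosine decay to final_value; indexing past total_iters simply returns final_value. A short sketch with hypothetical values (assuming dinov2 is importable):

from dinov2.utils.utils import CosineScheduler

sched = CosineScheduler(base_value=1.0, final_value=0.0, total_iters=100,
                        warmup_iters=10, start_warmup_value=0.0)
print(sched[0])    # 0.0  -- start of the linear warmup
print(sched[10])   # 1.0  -- warmup done, cosine segment starts at base_value
print(sched[55])   # ~0.5 -- halfway through the cosine segment
print(sched[500])  # 0.0  -- past total_iters, final_value is returned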
/torchhub/facebookresearch_dinov2_main/hubconf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | from enum import Enum
7 | from typing import Union
8 |
9 | import torch
10 |
11 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
12 |
13 |
14 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
15 | compact_arch_name = arch_name.replace("_", "")[:4]
16 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
17 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
18 |
19 |
20 | class Weights(Enum):
21 | LVD142M = "LVD142M"
22 |
23 |
24 | def _make_dinov2_model(
25 | *,
26 | arch_name: str = "vit_large",
27 | img_size: int = 518,
28 | patch_size: int = 14,
29 | init_values: float = 1.0,
30 | ffn_layer: str = "mlp",
31 | block_chunks: int = 0,
32 | num_register_tokens: int = 0,
33 | interpolate_antialias: bool = False,
34 | interpolate_offset: float = 0.1,
35 | pretrained: bool = True,
36 | weights: Union[Weights, str] = Weights.LVD142M,
37 | **kwargs,
38 | ):
39 | import vision_transformer as vits
40 |
41 | if isinstance(weights, str):
42 | try:
43 | weights = Weights[weights]
44 | except KeyError:
45 | raise AssertionError(f"Unsupported weights: {weights}")
46 |
47 | model_base_name = _make_dinov2_model_name(arch_name, patch_size)
48 | vit_kwargs = dict(
49 | img_size=img_size,
50 | patch_size=patch_size,
51 | init_values=init_values,
52 | ffn_layer=ffn_layer,
53 | block_chunks=block_chunks,
54 | num_register_tokens=num_register_tokens,
55 | interpolate_antialias=interpolate_antialias,
56 | interpolate_offset=interpolate_offset,
57 | )
58 | vit_kwargs.update(**kwargs)
59 | model = vits.__dict__[arch_name](**vit_kwargs)
60 |
61 | if pretrained:
62 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
63 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
64 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
65 | model.load_state_dict(state_dict, strict=True)
66 |
67 | return model
68 |
69 |
70 | def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
71 | """
72 | DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
73 | """
74 | return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
75 |
76 |
77 | def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
78 | """
79 | DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
80 | """
81 | return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
82 |
83 |
84 | def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
85 | """
86 | DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
87 | """
88 | return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
89 |
90 |
91 | def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
92 | """
93 | DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
94 | """
95 | return _make_dinov2_model(
96 | arch_name="vit_giant2",
97 | ffn_layer="swiglufused",
98 | weights=weights,
99 | pretrained=pretrained,
100 | **kwargs,
101 | )
102 |
103 |
104 | def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
105 | """
106 | DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
107 | """
108 | return _make_dinov2_model(
109 | arch_name="vit_small",
110 | pretrained=pretrained,
111 | weights=weights,
112 | num_register_tokens=4,
113 | interpolate_antialias=True,
114 | interpolate_offset=0.0,
115 | **kwargs,
116 | )
117 |
118 |
119 | def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
120 | """
121 | DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
122 | """
123 | return _make_dinov2_model(
124 | arch_name="vit_base",
125 | pretrained=pretrained,
126 | weights=weights,
127 | num_register_tokens=4,
128 | interpolate_antialias=True,
129 | interpolate_offset=0.0,
130 | **kwargs,
131 | )
132 |
133 |
134 | def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
135 | """
136 | DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
137 | """
138 | return _make_dinov2_model(
139 | arch_name="vit_large",
140 | pretrained=pretrained,
141 | weights=weights,
142 | num_register_tokens=4,
143 | interpolate_antialias=True,
144 | interpolate_offset=0.0,
145 | **kwargs,
146 | )
147 |
148 |
149 | def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
150 | """
151 | DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
152 | """
153 | return _make_dinov2_model(
154 | arch_name="vit_giant2",
155 | ffn_layer="swiglufused",
156 | weights=weights,
157 | pretrained=pretrained,
158 | num_register_tokens=4,
159 | interpolate_antialias=True,
160 | interpolate_offset=0.0,
161 | **kwargs,
162 | )
163 |
--------------------------------------------------------------------------------
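These entry points are meant to be consumed through torch.hub. A usage sketch that pulls the ViT-L/14 backbone from the upstream GitHub repository (a network connection is assumed, since the LVD-142M weights are downloaded on first use):

import torch

# Load the ViT-L/14 backbone through the hub entry point defined above
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14")
model.eval()

# Image side lengths should be multiples of the patch size (14)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model(x)  # CLS-token features, shape (1, 1024) for ViT-L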
/torchhub/facebookresearch_dinov2_main/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 |
4 | [tool.pylint.master]
5 | persistent = false
6 | score = false
7 |
8 | [tool.pylint.messages_control]
9 | disable = "all"
10 | enable = [
11 | "miscellaneous",
12 | "similarities",
13 | ]
14 |
15 | [tool.pylint.similarities]
16 | ignore-comments = true
17 | ignore-docstrings = true
18 | ignore-imports = true
19 | min-similarity-lines = 8
20 |
21 | [tool.pylint.reports]
22 | reports = false
23 |
24 | [tool.pylint.miscellaneous]
25 | notes = [
26 | "FIXME",
27 | "XXX",
28 | "TODO",
29 | ]
30 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/scripts/lint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ -n "$1" ]; then
4 | echo "linting \"$1\""
5 | fi
6 |
7 | echo "running black"
8 | if [ -n "$1" ]; then
9 | black "$1"
10 | else
11 | black dinov2
12 | fi
13 |
14 | echo "running flake8"
15 | if [ -n "$1" ]; then
16 | flake8 "$1"
17 | else
18 | flake8
19 | fi
20 |
21 | echo "running pylint"
22 | if [ -n "$1" ]; then
23 | pylint "$1"
24 | else
25 | pylint dinov2
26 | fi
27 |
28 | exit 0
29 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = E203,E501,W503
4 | per-file-ignores =
5 | __init__.py:F401
6 | exclude =
7 | venv
8 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from pathlib import Path
8 | import re
9 | from typing import List, Tuple
10 |
11 | from setuptools import setup, find_packages
12 |
13 |
14 | NAME = "dinov2"
15 | DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method."
16 |
17 | URL = "https://github.com/facebookresearch/dinov2"
18 | AUTHOR = "FAIR"
19 | REQUIRES_PYTHON = ">=3.9.0"
20 | HERE = Path(__file__).parent
21 |
22 |
23 | try:
24 | with open(HERE / "README.md", encoding="utf-8") as f:
25 | long_description = "\n" + f.read()
26 | except FileNotFoundError:
27 | long_description = DESCRIPTION
28 |
29 |
30 | def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]:
31 | requirements = []
32 | extra_indices = []
33 | with open(path) as f:
34 | for line in f.readlines():
35 | line = line.rstrip("\r\n")
36 | if line.startswith("--extra-index-url "):
37 | extra_indices.append(line[18:])
38 | continue
39 | requirements.append(line)
40 | return requirements, extra_indices
41 |
42 |
43 | def get_package_version() -> str:
44 | with open(HERE / "dinov2/__init__.py") as f:
45 | result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
46 | if result:
47 | return result.group(1)
48 | raise RuntimeError("Can't get package version")
49 |
50 |
51 | requirements, extra_indices = get_requirements()
52 | version = get_package_version()
53 | dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt")
54 |
55 |
56 | setup(
57 | name=NAME,
58 | version=version,
59 | description=DESCRIPTION,
60 | long_description=long_description,
61 | long_description_content_type="text/markdown",
62 | author=AUTHOR,
63 | python_requires=REQUIRES_PYTHON,
64 | url=URL,
65 | packages=find_packages(),
66 | package_data={
67 | "": ["*.yaml"],
68 | },
69 | install_requires=requirements,
70 | dependency_links=extra_indices,
71 | extras_require={
72 | "dev": dev_requirements,
73 | },
74 | include_package_data=True,
75 | license="CC-BY-NC",
76 | license_files=("LICENSE",),
77 | classifiers=[
78 | # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py
79 | "Development Status :: 3 - Alpha",
80 | "Intended Audience :: Developers",
81 | "Intended Audience :: Science/Research",
82 | "License :: Other/Proprietary License",
83 | "Programming Language :: Python :: 3.9",
84 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
85 | "Topic :: Software Development :: Libraries :: Python Modules",
86 | ],
87 | )
88 |
--------------------------------------------------------------------------------
/torchhub/facebookresearch_dinov2_main/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import itertools
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15 |
16 |
17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18 | compact_arch_name = arch_name.replace("_", "")[:4]
19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21 |
22 |
23 | class CenterPadding(nn.Module):
24 | def __init__(self, multiple):
25 | super().__init__()
26 | self.multiple = multiple
27 |
28 | def _get_pad(self, size):
29 | new_size = math.ceil(size / self.multiple) * self.multiple
30 | pad_size = new_size - size
31 | pad_size_left = pad_size // 2
32 | pad_size_right = pad_size - pad_size_left
33 | return pad_size_left, pad_size_right
34 |
35 | @torch.inference_mode()
36 | def forward(self, x):
37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
38 | output = F.pad(x, pads)
39 | return output
40 |
--------------------------------------------------------------------------------
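CenterPadding symmetrically zero-pads the trailing spatial dimensions up to the next multiple of `multiple`, which is how arbitrary image sizes are made compatible with the ViT patch grid. A quick sketch, assuming this module is importable as `utils` (it sits at the top level of the vendored torchhub directory):

import torch
from utils import CenterPadding  # i.e. torchhub/facebookresearch_dinov2_main/utils.py

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 480, 640)
y = pad(x)
print(y.shape)  # torch.Size([1, 3, 490, 644]) -- 480 -> 490 and 640 -> 644, padding split evenly on each side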