├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── attention_maps.png └── dino.gif ├── LICENSE ├── README.md ├── eval_copy_detection.py ├── eval_image_retrieval.py ├── eval_knn.py ├── eval_linear.py ├── eval_video_segmentation.py ├── hubconf.py ├── main_dino.py ├── run_with_submitit.py ├── utils.py ├── video_generation.py ├── vision_transformer.py └── visualize_attention.py /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | In the context of this project, we do not expect pull requests. 4 | If you find a bug, or would like to suggest an improvement, please open an issue. 5 | -------------------------------------------------------------------------------- /.github/attention_maps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dino/7c446df5b9f45747937fb0d72314eb9f7b66930a/.github/attention_maps.png -------------------------------------------------------------------------------- /.github/dino.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/dino/7c446df5b9f45747937fb0d72314eb9f7b66930a/.github/dino.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | :new: *Please check out our more recent [DINOv2](https://github.com/facebookresearch/dinov2) effort in the same line of work.* 2 | 3 | # Self-Supervised Vision Transformers with DINO 4 | 5 | PyTorch implementation and pretrained models for DINO. For details, see **Emerging Properties in Self-Supervised Vision Transformers**. 6 | [[`blogpost`](https://ai.facebook.com/blog/dino-paws-computer-vision-with-self-supervised-transformers-and-10x-more-efficient-training)] [[`arXiv`](https://arxiv.org/abs/2104.14294)] [[`Yannic Kilcher's video`](https://www.youtube.com/watch?v=h3ij3F3cPIk)] 7 | 8 |
9 | DINO illustration 10 |
11 | 12 | ## Pretrained models 13 | You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the backbone in `onnx` format, as well as detailed arguments and training/evaluation logs. Note that `DeiT-S` and `ViT-S` names refer exactly to the same architecture. 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 |
archparamsk-nnlineardownload
ViT-S/1621M74.5%77.0%backbone onlyfull ckptonnxargslogseval logs
ViT-S/821M78.3%79.7%backbone onlyfull ckptonnxargslogseval logs
ViT-B/1685M76.1%78.2%backbone onlyfull ckptonnxargslogseval logs
ViT-B/885M77.4%80.1%backbone onlyfull ckptonnxargslogseval logs
ResNet-5023M67.5%75.3%backbone onlyfull ckptonnxargslogseval logs
84 | 85 | We also release XCiT models ([[`arXiv`](https://arxiv.org/abs/2106.09681)] [[`code`](https://github.com/facebookresearch/xcit)]) trained with DINO: 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 |
archparamsk-nnlineardownload
xcit_small_12_p1626M76.0%77.8%backbone onlyfull ckptargslogseval
xcit_small_12_p826M77.1%79.2%backbone onlyfull ckptargslogseval
xcit_medium_24_p1684M76.4%78.8%backbone onlyfull ckptargslogseval
xcit_medium_24_p884M77.9%80.3%backbone onlyfull ckptargslogseval
139 | 140 | ### Pretrained models on PyTorch Hub 141 | ```python 142 | import torch 143 | vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16') 144 | vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8') 145 | vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 146 | vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8') 147 | xcit_small_12_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p16') 148 | xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p8') 149 | xcit_medium_24_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p16') 150 | xcit_medium_24_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8') 151 | resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50') 152 | ``` 153 | 154 | ## Training 155 | 156 | ### Documentation 157 | Please install [PyTorch](https://pytorch.org/) and download the [ImageNet](https://imagenet.stanford.edu/) dataset. This codebase has been developed with python version 3.6, PyTorch version 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments to reproduce the models presented in our paper can be found in the `args` column of the [pretrained models section](https://github.com/facebookresearch/dino#pretrained-models). For a glimpse at the full documentation of DINO training please run: 158 | ``` 159 | python main_dino.py --help 160 | ``` 161 | 162 | ### Vanilla DINO training :sauropod: 163 | Run DINO with ViT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility. 164 | ``` 165 | python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 166 | ``` 167 | 168 | ### Multi-node training 169 | We use Slurm and [submitit](https://github.com/facebookincubator/submitit) (`pip install submitit`). To train on 2 nodes with 8 GPUs each (total 16 GPUs): 170 | ``` 171 | python run_with_submitit.py --nodes 2 --ngpus 8 --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 172 | ``` 173 | 174 |
175 | 176 | DINO with ViT-base network. 177 | 178 | 179 | ``` 180 | python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 181 | ``` 182 | 183 |
184 | 185 | ### Boosting DINO performance :t-rex: 186 | You can improve the performance of the vanilla run by: 187 | - training for more epochs: `--epochs 300`, 188 | - increasing the teacher temperature: `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30`. 189 | - removing last layer normalization (only safe with `--arch vit_small`): `--norm_last_layer false`, 190 | 191 |
192 | 193 | Full command. 194 | 195 | 196 | ``` 197 | python run_with_submitit.py --arch vit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 198 | ``` 199 | 200 |
201 | 202 | The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility. 203 | 204 | ### ResNet-50 and other convnets trainings 205 | This code also works for training DINO on convolutional networks, like ResNet-50 for example. We highly recommend to adapt some optimization arguments in this case. For example following is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training logs](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) and [final checkpoint](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_checkpoint.pth) for this run. 206 | ``` 207 | python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir 208 | ``` 209 | 210 | ## Self-attention visualization 211 | You can look at the self-attention of the [CLS] token on the different heads of the last layer by running: 212 | ``` 213 | python visualize_attention.py 214 | ``` 215 | 216 |
217 | Self-attention from a Vision Transformer with 8x8 patches trained with DINO 218 |
219 | 220 | ## Self-attention video generation 221 | You can generate videos like the one on the blog post with `video_generation.py`. 222 | 223 | https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4 224 | 225 | Extract frames from input video and generate attention video: 226 | ``` 227 | python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \ 228 | --input_path input/video.mp4 \ 229 | --output_path output/ \ 230 | --fps 25 231 | ``` 232 | 233 | Use folder of frames already extracted and generate attention video: 234 | ``` 235 | python video_generation.py --pretrained_weights dino_deitsmall8_pretrain.pth \ 236 | --input_path output/frames/ \ 237 | --output_path output/ \ 238 | --resize 256 \ 239 | ``` 240 | 241 | Only generate video from folder of attention maps images: 242 | ``` 243 | python video_generation.py --input_path output/attention \ 244 | --output_path output/ \ 245 | --video_only \ 246 | --video_format avi 247 | ``` 248 | 249 | 250 | ## Evaluation: k-NN classification on ImageNet 251 | To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run: 252 | ``` 253 | python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet 254 | ``` 255 | If you choose not to specify `--pretrained_weights`, then DINO reference weights are used by default. If you want instead to evaluate checkpoints from a run of your own, you can run for example: 256 | ``` 257 | python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 258 | ``` 259 | 260 | ## Evaluation: Linear classification on ImageNet 261 | To train a supervised linear classifier on frozen weights on a single node with 8 gpus, run: 262 | ``` 263 | python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet 264 | ``` 265 | 266 | We release the logs and weights from evaluating the different models: 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 |
archtop-1 ImageNetlinear evaluation
ViT-S/1677.0%linear weightslogs
ViT-S/879.7%linear weightslogs
ViT-B/1678.2%linear weightslogs
ViT-B/880.1%linear weightslogs
xcit_small_12_p1677.8%linear weightslogs
xcit_small_12_p879.2%linear weightslogs
xcit_medium_24_p1678.8%linear weightslogs
xcit_medium_24_p880.3%linear weightslogs
ResNet-5075.3%linear weightslogs
329 | 330 | You can check the performance of the pretrained weights on ImageNet validation set by running the following command lines: 331 | ``` 332 | python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet/train 333 | ``` 334 | 335 | ``` 336 | python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet/train 337 | ``` 338 | 339 | ``` 340 | python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train 341 | ``` 342 | 343 | ``` 344 | python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train 345 | ``` 346 | 347 | ``` 348 | python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet/train 349 | ``` 350 | 351 | ## Evaluation: DAVIS 2017 Video object segmentation 352 | Please verify that you're using pytorch version 1.7.1 since we are not able to reproduce the results with most recent pytorch 1.8.1 at the moment. 353 | 354 | **Step 1: Prepare DAVIS 2017 data** 355 | ``` 356 | cd $HOME 357 | git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017 358 | ./data/get_davis.sh 359 | ``` 360 | 361 | **Step 2: Video object segmentation** 362 | ``` 363 | python eval_video_segmentation.py --data_path $HOME/davis-2017/DAVIS/ --output_dir /path/to/saving_dir 364 | ``` 365 | 366 | **Step 3: Evaluate the obtained segmentation** 367 | ``` 368 | git clone https://github.com/davisvideochallenge/davis2017-evaluation $HOME/davis2017-evaluation 369 | python $HOME/davis2017-evaluation/evaluation_method.py --task semi-supervised --results_path /path/to/saving_dir --davis_path $HOME/davis-2017/DAVIS/ 370 | ``` 371 | 372 | ## Evaluation: Image Retrieval on revisited Oxford and Paris 373 | Step 1: Prepare revisited Oxford and Paris by following [this repo](https://github.com/filipradenovic/revisitop). 374 | 375 | Step 2: Image retrieval (if you do not specify weights with `--pretrained_weights` then by default [DINO weights pretrained on Google Landmark v2 dataset](https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth) will be used). 376 | 377 | Paris: 378 | ``` 379 | python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 512 --multiscale 1 --data_path /path/to/revisited_paris_oxford/ --dataset rparis6k 380 | ``` 381 | 382 | Oxford: 383 | ``` 384 | python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 0 --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k 385 | ``` 386 | 387 | ## Evaluation: Copy detection on Copydays 388 | Step 1: Prepare [Copydays dataset](https://lear.inrialpes.fr/~jegou/data.php#copydays). 389 | 390 | Step 2 (opt): Prepare a set of image distractors and a set of images on which to learn the whitening operator. 391 | In our paper, we use 10k random images from YFCC100M as distractors and 20k random images from YFCC100M (different from the distractors) for computing the whitening operation. 392 | 393 | Step 3: Run copy detection: 394 | ``` 395 | python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_copy_detection.py --data_path /path/to/copydays/ --whitening_path /path/to/whitening_data/ --distractors_path /path/to/distractors/ 396 | ``` 397 | We report result on the strong subset. For example in the stdout from the command above we get: `eval on strong mAP=0.858`. 398 | 399 | ## License 400 | This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file. 401 | 402 | ## Citation 403 | If you find this repository useful, please consider giving a star :star: and citation :t-rex:: 404 | ``` 405 | @inproceedings{caron2021emerging, 406 | title={Emerging Properties in Self-Supervised Vision Transformers}, 407 | author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand}, 408 | booktitle={Proceedings of the International Conference on Computer Vision (ICCV)}, 409 | year={2021} 410 | } 411 | ``` 412 | -------------------------------------------------------------------------------- /eval_copy_detection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import pickle 17 | import argparse 18 | 19 | import torch 20 | from torch import nn 21 | import torch.distributed as dist 22 | import torch.backends.cudnn as cudnn 23 | from torchvision import models as torchvision_models 24 | from torchvision import transforms as pth_transforms 25 | from PIL import Image, ImageFile 26 | import numpy as np 27 | 28 | import utils 29 | import vision_transformer as vits 30 | from eval_knn import extract_features 31 | 32 | 33 | class CopydaysDataset(): 34 | def __init__(self, basedir): 35 | self.basedir = basedir 36 | self.block_names = ( 37 | ['original', 'strong'] + 38 | ['jpegqual/%d' % i for i in 39 | [3, 5, 8, 10, 15, 20, 30, 50, 75]] + 40 | ['crops/%d' % i for i in 41 | [10, 15, 20, 30, 40, 50, 60, 70, 80]]) 42 | self.nblocks = len(self.block_names) 43 | 44 | self.query_blocks = range(self.nblocks) 45 | self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157 46 | self.q_block_sizes[1] = 229 47 | # search only among originals 48 | self.database_blocks = [0] 49 | 50 | def get_block(self, i): 51 | dirname = self.basedir + '/' + self.block_names[i] 52 | fnames = [dirname + '/' + fname 53 | for fname in sorted(os.listdir(dirname)) 54 | if fname.endswith('.jpg')] 55 | return fnames 56 | 57 | def get_block_filenames(self, subdir_name): 58 | dirname = self.basedir + '/' + subdir_name 59 | return [fname 60 | for fname in sorted(os.listdir(dirname)) 61 | if fname.endswith('.jpg')] 62 | 63 | def eval_result(self, ids, distances): 64 | j0 = 0 65 | for i in range(self.nblocks): 66 | j1 = j0 + self.q_block_sizes[i] 67 | block_name = self.block_names[i] 68 | I = ids[j0:j1] # block size 69 | sum_AP = 0 70 | if block_name != 'strong': 71 | # 1:1 mapping of files to names 72 | positives_per_query = [[i] for i in range(j1 - j0)] 73 | else: 74 | originals = self.get_block_filenames('original') 75 | strongs = self.get_block_filenames('strong') 76 | 77 | # check if prefixes match 78 | positives_per_query = [ 79 | [j for j, bname in enumerate(originals) 80 | if bname[:4] == qname[:4]] 81 | for qname in strongs] 82 | 83 | for qno, Iline in enumerate(I): 84 | positives = positives_per_query[qno] 85 | ranks = [] 86 | for rank, bno in enumerate(Iline): 87 | if bno in positives: 88 | ranks.append(rank) 89 | sum_AP += score_ap_from_ranks_1(ranks, len(positives)) 90 | 91 | print("eval on %s mAP=%.3f" % ( 92 | block_name, sum_AP / (j1 - j0))) 93 | j0 = j1 94 | 95 | 96 | # from the Holidays evaluation package 97 | def score_ap_from_ranks_1(ranks, nres): 98 | """ Compute the average precision of one search. 99 | ranks = ordered list of ranks of true positives 100 | nres = total number of positives in dataset 101 | """ 102 | 103 | # accumulate trapezoids in PR-plot 104 | ap = 0.0 105 | 106 | # All have an x-size of: 107 | recall_step = 1.0 / nres 108 | 109 | for ntp, rank in enumerate(ranks): 110 | 111 | # y-size on left side of trapezoid: 112 | # ntp = nb of true positives so far 113 | # rank = nb of retrieved items so far 114 | if rank == 0: 115 | precision_0 = 1.0 116 | else: 117 | precision_0 = ntp / float(rank) 118 | 119 | # y-size on right side of trapezoid: 120 | # ntp and rank are increased by one 121 | precision_1 = (ntp + 1) / float(rank + 1) 122 | 123 | ap += (precision_1 + precision_0) * recall_step / 2.0 124 | 125 | return ap 126 | 127 | 128 | class ImgListDataset(torch.utils.data.Dataset): 129 | def __init__(self, img_list, transform=None): 130 | self.samples = img_list 131 | self.transform = transform 132 | 133 | def __getitem__(self, i): 134 | with open(self.samples[i], 'rb') as f: 135 | img = Image.open(f) 136 | img = img.convert('RGB') 137 | if self.transform is not None: 138 | img = self.transform(img) 139 | return img, i 140 | 141 | def __len__(self): 142 | return len(self.samples) 143 | 144 | 145 | def is_image_file(s): 146 | ext = s.split(".")[-1] 147 | if ext in ['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp']: 148 | return True 149 | return False 150 | 151 | 152 | @torch.no_grad() 153 | def extract_features(image_list, model, args): 154 | transform = pth_transforms.Compose([ 155 | pth_transforms.Resize((args.imsize, args.imsize), interpolation=3), 156 | pth_transforms.ToTensor(), 157 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 158 | ]) 159 | tempdataset = ImgListDataset(image_list, transform=transform) 160 | data_loader = torch.utils.data.DataLoader(tempdataset, batch_size=args.batch_size_per_gpu, 161 | num_workers=args.num_workers, drop_last=False, 162 | sampler=torch.utils.data.DistributedSampler(tempdataset, shuffle=False)) 163 | features = None 164 | for samples, index in utils.MetricLogger(delimiter=" ").log_every(data_loader, 10): 165 | samples, index = samples.cuda(non_blocking=True), index.cuda(non_blocking=True) 166 | feats = model.get_intermediate_layers(samples, n=1)[0].clone() 167 | 168 | cls_output_token = feats[:, 0, :] # [CLS] token 169 | # GeM with exponent 4 for output patch tokens 170 | b, h, w, d = len(samples), int(samples.shape[-2] / model.patch_embed.patch_size), int(samples.shape[-1] / model.patch_embed.patch_size), feats.shape[-1] 171 | feats = feats[:, 1:, :].reshape(b, h, w, d) 172 | feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2) 173 | feats = nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1. / 4).reshape(b, -1) 174 | # concatenate [CLS] token and GeM pooled patch tokens 175 | feats = torch.cat((cls_output_token, feats), dim=1) 176 | 177 | # init storage feature matrix 178 | if dist.get_rank() == 0 and features is None: 179 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) 180 | if args.use_cuda: 181 | features = features.cuda(non_blocking=True) 182 | 183 | # get indexes from all processes 184 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device) 185 | y_l = list(y_all.unbind(0)) 186 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) 187 | y_all_reduce.wait() 188 | index_all = torch.cat(y_l) 189 | 190 | # share features between processes 191 | feats_all = torch.empty(dist.get_world_size(), feats.size(0), feats.size(1), 192 | dtype=feats.dtype, device=feats.device) 193 | output_l = list(feats_all.unbind(0)) 194 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True) 195 | output_all_reduce.wait() 196 | 197 | # update storage feature matrix 198 | if dist.get_rank() == 0: 199 | if args.use_cuda: 200 | features.index_copy_(0, index_all, torch.cat(output_l)) 201 | else: 202 | features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu()) 203 | return features # features is still None for every rank which is not 0 (main) 204 | 205 | 206 | if __name__ == '__main__': 207 | parser = argparse.ArgumentParser('Copy detection on Copydays') 208 | parser.add_argument('--data_path', default='/path/to/copydays/', type=str, 209 | help="See https://lear.inrialpes.fr/~jegou/data.php#copydays") 210 | parser.add_argument('--whitening_path', default='/path/to/whitening_data/', type=str, 211 | help="""Path to directory with images used for computing the whitening operator. 212 | In our paper, we use 20k random images from YFCC100M.""") 213 | parser.add_argument('--distractors_path', default='/path/to/distractors/', type=str, 214 | help="Path to directory with distractors images. In our paper, we use 10k random images from YFCC100M.") 215 | parser.add_argument('--imsize', default=320, type=int, help='Image size (square image)') 216 | parser.add_argument('--batch_size_per_gpu', default=16, type=int, help='Per-GPU batch-size') 217 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 218 | parser.add_argument('--use_cuda', default=True, type=utils.bool_flag) 219 | parser.add_argument('--arch', default='vit_base', type=str, help='Architecture') 220 | parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.') 221 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 222 | help='Key to use in the checkpoint (example: "teacher")') 223 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 224 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 225 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 226 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 227 | args = parser.parse_args() 228 | 229 | utils.init_distributed_mode(args) 230 | print("git:\n {}\n".format(utils.get_sha())) 231 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 232 | cudnn.benchmark = True 233 | 234 | # ============ building network ... ============ 235 | if "vit" in args.arch: 236 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 237 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") 238 | else: 239 | print(f"Architecture {args.arch} non supported") 240 | sys.exit(1) 241 | if args.use_cuda: 242 | model.cuda() 243 | model.eval() 244 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) 245 | 246 | dataset = CopydaysDataset(args.data_path) 247 | 248 | # ============ Extract features ... ============ 249 | # extract features for queries 250 | queries = [] 251 | for q in dataset.query_blocks: 252 | queries.append(extract_features(dataset.get_block(q), model, args)) 253 | if utils.get_rank() == 0: 254 | queries = torch.cat(queries) 255 | print(f"Extraction of queries features done. Shape: {queries.shape}") 256 | 257 | # extract features for database 258 | database = [] 259 | for b in dataset.database_blocks: 260 | database.append(extract_features(dataset.get_block(b), model, args)) 261 | 262 | # extract features for distractors 263 | if os.path.isdir(args.distractors_path): 264 | print("Using distractors...") 265 | list_distractors = [os.path.join(args.distractors_path, s) for s in os.listdir(args.distractors_path) if is_image_file(s)] 266 | database.append(extract_features(list_distractors, model, args)) 267 | if utils.get_rank() == 0: 268 | database = torch.cat(database) 269 | print(f"Extraction of database and distractors features done. Shape: {database.shape}") 270 | 271 | # ============ Whitening ... ============ 272 | if os.path.isdir(args.whitening_path): 273 | print(f"Extracting features on images from {args.whitening_path} for learning the whitening operator.") 274 | list_whit = [os.path.join(args.whitening_path, s) for s in os.listdir(args.whitening_path) if is_image_file(s)] 275 | features_for_whitening = extract_features(list_whit, model, args) 276 | if utils.get_rank() == 0: 277 | # center 278 | mean_feature = torch.mean(features_for_whitening, dim=0) 279 | database -= mean_feature 280 | queries -= mean_feature 281 | pca = utils.PCA(dim=database.shape[-1], whit=0.5) 282 | # compute covariance 283 | cov = torch.mm(features_for_whitening.T, features_for_whitening) / features_for_whitening.shape[0] 284 | pca.train_pca(cov.cpu().numpy()) 285 | database = pca.apply(database) 286 | queries = pca.apply(queries) 287 | 288 | # ============ Copy detection ... ============ 289 | if utils.get_rank() == 0: 290 | # l2 normalize the features 291 | database = nn.functional.normalize(database, dim=1, p=2) 292 | queries = nn.functional.normalize(queries, dim=1, p=2) 293 | 294 | # similarity 295 | similarity = torch.mm(queries, database.T) 296 | distances, indices = similarity.topk(20, largest=True, sorted=True) 297 | 298 | # evaluate 299 | retrieved = dataset.eval_result(indices, distances) 300 | dist.barrier() 301 | 302 | -------------------------------------------------------------------------------- /eval_image_retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import pickle 17 | import argparse 18 | 19 | import torch 20 | from torch import nn 21 | import torch.distributed as dist 22 | import torch.backends.cudnn as cudnn 23 | from torchvision import models as torchvision_models 24 | from torchvision import transforms as pth_transforms 25 | from PIL import Image, ImageFile 26 | import numpy as np 27 | 28 | import utils 29 | import vision_transformer as vits 30 | from eval_knn import extract_features 31 | 32 | 33 | class OxfordParisDataset(torch.utils.data.Dataset): 34 | def __init__(self, dir_main, dataset, split, transform=None, imsize=None): 35 | if dataset not in ['roxford5k', 'rparis6k']: 36 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 37 | 38 | # loading imlist, qimlist, and gnd, in cfg as a dict 39 | gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset)) 40 | with open(gnd_fname, 'rb') as f: 41 | cfg = pickle.load(f) 42 | cfg['gnd_fname'] = gnd_fname 43 | cfg['ext'] = '.jpg' 44 | cfg['qext'] = '.jpg' 45 | cfg['dir_data'] = os.path.join(dir_main, dataset) 46 | cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg') 47 | cfg['n'] = len(cfg['imlist']) 48 | cfg['nq'] = len(cfg['qimlist']) 49 | cfg['im_fname'] = config_imname 50 | cfg['qim_fname'] = config_qimname 51 | cfg['dataset'] = dataset 52 | self.cfg = cfg 53 | 54 | self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"] 55 | self.transform = transform 56 | self.imsize = imsize 57 | 58 | def __len__(self): 59 | return len(self.samples) 60 | 61 | def __getitem__(self, index): 62 | path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg") 63 | ImageFile.LOAD_TRUNCATED_IMAGES = True 64 | with open(path, 'rb') as f: 65 | img = Image.open(f) 66 | img = img.convert('RGB') 67 | if self.imsize is not None: 68 | img.thumbnail((self.imsize, self.imsize), Image.ANTIALIAS) 69 | if self.transform is not None: 70 | img = self.transform(img) 71 | return img, index 72 | 73 | 74 | def config_imname(cfg, i): 75 | return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext']) 76 | 77 | 78 | def config_qimname(cfg, i): 79 | return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext']) 80 | 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser('Image Retrieval on revisited Paris and Oxford') 84 | parser.add_argument('--data_path', default='/path/to/revisited_paris_oxford/', type=str) 85 | parser.add_argument('--dataset', default='roxford5k', type=str, choices=['roxford5k', 'rparis6k']) 86 | parser.add_argument('--multiscale', default=False, type=utils.bool_flag) 87 | parser.add_argument('--imsize', default=224, type=int, help='Image size') 88 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 89 | parser.add_argument('--use_cuda', default=True, type=utils.bool_flag) 90 | parser.add_argument('--arch', default='vit_small', type=str, help='Architecture') 91 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 92 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 93 | help='Key to use in the checkpoint (example: "teacher")') 94 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 95 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 96 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 97 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 98 | args = parser.parse_args() 99 | 100 | utils.init_distributed_mode(args) 101 | print("git:\n {}\n".format(utils.get_sha())) 102 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 103 | cudnn.benchmark = True 104 | 105 | # ============ preparing data ... ============ 106 | transform = pth_transforms.Compose([ 107 | pth_transforms.ToTensor(), 108 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 109 | ]) 110 | dataset_train = OxfordParisDataset(args.data_path, args.dataset, split="train", transform=transform, imsize=args.imsize) 111 | dataset_query = OxfordParisDataset(args.data_path, args.dataset, split="query", transform=transform, imsize=args.imsize) 112 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) 113 | data_loader_train = torch.utils.data.DataLoader( 114 | dataset_train, 115 | sampler=sampler, 116 | batch_size=1, 117 | num_workers=args.num_workers, 118 | pin_memory=True, 119 | drop_last=False, 120 | ) 121 | data_loader_query = torch.utils.data.DataLoader( 122 | dataset_query, 123 | batch_size=1, 124 | num_workers=args.num_workers, 125 | pin_memory=True, 126 | drop_last=False, 127 | ) 128 | print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs") 129 | 130 | # ============ building network ... ============ 131 | if "vit" in args.arch: 132 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 133 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") 134 | elif "xcit" in args.arch: 135 | model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0) 136 | elif args.arch in torchvision_models.__dict__.keys(): 137 | model = torchvision_models.__dict__[args.arch](num_classes=0) 138 | else: 139 | print(f"Architecture {args.arch} non supported") 140 | sys.exit(1) 141 | if args.use_cuda: 142 | model.cuda() 143 | model.eval() 144 | 145 | # load pretrained weights 146 | if os.path.isfile(args.pretrained_weights): 147 | state_dict = torch.load(args.pretrained_weights, map_location="cpu") 148 | if args.checkpoint_key is not None and args.checkpoint_key in state_dict: 149 | print(f"Take key {args.checkpoint_key} in provided checkpoint dict") 150 | state_dict = state_dict[args.checkpoint_key] 151 | # remove `module.` prefix 152 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 153 | # remove `backbone.` prefix induced by multicrop wrapper 154 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 155 | msg = model.load_state_dict(state_dict, strict=False) 156 | print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg)) 157 | elif args.arch == "vit_small" and args.patch_size == 16: 158 | print("Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2.") 159 | model.load_state_dict(torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth")) 160 | else: 161 | print("Warning: We use random weights.") 162 | 163 | ############################################################################ 164 | # Step 1: extract features 165 | train_features = extract_features(model, data_loader_train, args.use_cuda, multiscale=args.multiscale) 166 | query_features = extract_features(model, data_loader_query, args.use_cuda, multiscale=args.multiscale) 167 | 168 | if utils.get_rank() == 0: # only rank 0 will work from now on 169 | # normalize features 170 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 171 | query_features = nn.functional.normalize(query_features, dim=1, p=2) 172 | 173 | ############################################################################ 174 | # Step 2: similarity 175 | sim = torch.mm(train_features, query_features.T) 176 | ranks = torch.argsort(-sim, dim=0).cpu().numpy() 177 | 178 | ############################################################################ 179 | # Step 3: evaluate 180 | gnd = dataset_train.cfg['gnd'] 181 | # evaluate ranks 182 | ks = [1, 5, 10] 183 | # search for easy & hard 184 | gnd_t = [] 185 | for i in range(len(gnd)): 186 | g = {} 187 | g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']]) 188 | g['junk'] = np.concatenate([gnd[i]['junk']]) 189 | gnd_t.append(g) 190 | mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks) 191 | # search for hard 192 | gnd_t = [] 193 | for i in range(len(gnd)): 194 | g = {} 195 | g['ok'] = np.concatenate([gnd[i]['hard']]) 196 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']]) 197 | gnd_t.append(g) 198 | mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks) 199 | print('>> {}: mAP M: {}, H: {}'.format(args.dataset, np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2))) 200 | print('>> {}: mP@k{} M: {}, H: {}'.format(args.dataset, np.array(ks), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2))) 201 | dist.barrier() 202 | -------------------------------------------------------------------------------- /eval_knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import argparse 17 | 18 | import torch 19 | from torch import nn 20 | import torch.distributed as dist 21 | import torch.backends.cudnn as cudnn 22 | from torchvision import datasets 23 | from torchvision import transforms as pth_transforms 24 | from torchvision import models as torchvision_models 25 | 26 | import utils 27 | import vision_transformer as vits 28 | 29 | 30 | def extract_feature_pipeline(args): 31 | # ============ preparing data ... ============ 32 | transform = pth_transforms.Compose([ 33 | pth_transforms.Resize(256, interpolation=3), 34 | pth_transforms.CenterCrop(224), 35 | pth_transforms.ToTensor(), 36 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 37 | ]) 38 | dataset_train = ReturnIndexDataset(os.path.join(args.data_path, "train"), transform=transform) 39 | dataset_val = ReturnIndexDataset(os.path.join(args.data_path, "val"), transform=transform) 40 | sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False) 41 | data_loader_train = torch.utils.data.DataLoader( 42 | dataset_train, 43 | sampler=sampler, 44 | batch_size=args.batch_size_per_gpu, 45 | num_workers=args.num_workers, 46 | pin_memory=True, 47 | drop_last=False, 48 | ) 49 | data_loader_val = torch.utils.data.DataLoader( 50 | dataset_val, 51 | batch_size=args.batch_size_per_gpu, 52 | num_workers=args.num_workers, 53 | pin_memory=True, 54 | drop_last=False, 55 | ) 56 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 57 | 58 | # ============ building network ... ============ 59 | if "vit" in args.arch: 60 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 61 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") 62 | elif "xcit" in args.arch: 63 | model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0) 64 | elif args.arch in torchvision_models.__dict__.keys(): 65 | model = torchvision_models.__dict__[args.arch](num_classes=0) 66 | model.fc = nn.Identity() 67 | else: 68 | print(f"Architecture {args.arch} non supported") 69 | sys.exit(1) 70 | model.cuda() 71 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) 72 | model.eval() 73 | 74 | # ============ extract features ... ============ 75 | print("Extracting features for train set...") 76 | train_features = extract_features(model, data_loader_train, args.use_cuda) 77 | print("Extracting features for val set...") 78 | test_features = extract_features(model, data_loader_val, args.use_cuda) 79 | 80 | if utils.get_rank() == 0: 81 | train_features = nn.functional.normalize(train_features, dim=1, p=2) 82 | test_features = nn.functional.normalize(test_features, dim=1, p=2) 83 | 84 | train_labels = torch.tensor([s[-1] for s in dataset_train.samples]).long() 85 | test_labels = torch.tensor([s[-1] for s in dataset_val.samples]).long() 86 | # save features and labels 87 | if args.dump_features and dist.get_rank() == 0: 88 | torch.save(train_features.cpu(), os.path.join(args.dump_features, "trainfeat.pth")) 89 | torch.save(test_features.cpu(), os.path.join(args.dump_features, "testfeat.pth")) 90 | torch.save(train_labels.cpu(), os.path.join(args.dump_features, "trainlabels.pth")) 91 | torch.save(test_labels.cpu(), os.path.join(args.dump_features, "testlabels.pth")) 92 | return train_features, test_features, train_labels, test_labels 93 | 94 | 95 | @torch.no_grad() 96 | def extract_features(model, data_loader, use_cuda=True, multiscale=False): 97 | metric_logger = utils.MetricLogger(delimiter=" ") 98 | features = None 99 | for samples, index in metric_logger.log_every(data_loader, 10): 100 | samples = samples.cuda(non_blocking=True) 101 | index = index.cuda(non_blocking=True) 102 | if multiscale: 103 | feats = utils.multi_scale(samples, model) 104 | else: 105 | feats = model(samples).clone() 106 | 107 | # init storage feature matrix 108 | if dist.get_rank() == 0 and features is None: 109 | features = torch.zeros(len(data_loader.dataset), feats.shape[-1]) 110 | if use_cuda: 111 | features = features.cuda(non_blocking=True) 112 | print(f"Storing features into tensor of shape {features.shape}") 113 | 114 | # get indexes from all processes 115 | y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device) 116 | y_l = list(y_all.unbind(0)) 117 | y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True) 118 | y_all_reduce.wait() 119 | index_all = torch.cat(y_l) 120 | 121 | # share features between processes 122 | feats_all = torch.empty( 123 | dist.get_world_size(), 124 | feats.size(0), 125 | feats.size(1), 126 | dtype=feats.dtype, 127 | device=feats.device, 128 | ) 129 | output_l = list(feats_all.unbind(0)) 130 | output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True) 131 | output_all_reduce.wait() 132 | 133 | # update storage feature matrix 134 | if dist.get_rank() == 0: 135 | if use_cuda: 136 | features.index_copy_(0, index_all, torch.cat(output_l)) 137 | else: 138 | features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu()) 139 | return features 140 | 141 | 142 | @torch.no_grad() 143 | def knn_classifier(train_features, train_labels, test_features, test_labels, k, T, num_classes=1000): 144 | top1, top5, total = 0.0, 0.0, 0 145 | train_features = train_features.t() 146 | num_test_images, num_chunks = test_labels.shape[0], 100 147 | imgs_per_chunk = num_test_images // num_chunks 148 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device) 149 | for idx in range(0, num_test_images, imgs_per_chunk): 150 | # get the features for test images 151 | features = test_features[ 152 | idx : min((idx + imgs_per_chunk), num_test_images), : 153 | ] 154 | targets = test_labels[idx : min((idx + imgs_per_chunk), num_test_images)] 155 | batch_size = targets.shape[0] 156 | 157 | # calculate the dot product and compute top-k neighbors 158 | similarity = torch.mm(features, train_features) 159 | distances, indices = similarity.topk(k, largest=True, sorted=True) 160 | candidates = train_labels.view(1, -1).expand(batch_size, -1) 161 | retrieved_neighbors = torch.gather(candidates, 1, indices) 162 | 163 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 164 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 165 | distances_transform = distances.clone().div_(T).exp_() 166 | probs = torch.sum( 167 | torch.mul( 168 | retrieval_one_hot.view(batch_size, -1, num_classes), 169 | distances_transform.view(batch_size, -1, 1), 170 | ), 171 | 1, 172 | ) 173 | _, predictions = probs.sort(1, True) 174 | 175 | # find the predictions that match the target 176 | correct = predictions.eq(targets.data.view(-1, 1)) 177 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 178 | top5 = top5 + correct.narrow(1, 0, min(5, k)).sum().item() # top5 does not make sense if k < 5 179 | total += targets.size(0) 180 | top1 = top1 * 100.0 / total 181 | top5 = top5 * 100.0 / total 182 | return top1, top5 183 | 184 | 185 | class ReturnIndexDataset(datasets.ImageFolder): 186 | def __getitem__(self, idx): 187 | img, lab = super(ReturnIndexDataset, self).__getitem__(idx) 188 | return img, idx 189 | 190 | 191 | if __name__ == '__main__': 192 | parser = argparse.ArgumentParser('Evaluation with weighted k-NN on ImageNet') 193 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 194 | parser.add_argument('--nb_knn', default=[10, 20, 100, 200], nargs='+', type=int, 195 | help='Number of NN to use. 20 is usually working the best.') 196 | parser.add_argument('--temperature', default=0.07, type=float, 197 | help='Temperature used in the voting coefficient') 198 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 199 | parser.add_argument('--use_cuda', default=True, type=utils.bool_flag, 200 | help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM") 201 | parser.add_argument('--arch', default='vit_small', type=str, help='Architecture') 202 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 203 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 204 | help='Key to use in the checkpoint (example: "teacher")') 205 | parser.add_argument('--dump_features', default=None, 206 | help='Path where to save computed features, empty for no saving') 207 | parser.add_argument('--load_features', default=None, help="""If the features have 208 | already been computed, where to find them.""") 209 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 210 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 211 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 212 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 213 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 214 | args = parser.parse_args() 215 | 216 | utils.init_distributed_mode(args) 217 | print("git:\n {}\n".format(utils.get_sha())) 218 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 219 | cudnn.benchmark = True 220 | 221 | if args.load_features: 222 | train_features = torch.load(os.path.join(args.load_features, "trainfeat.pth")) 223 | test_features = torch.load(os.path.join(args.load_features, "testfeat.pth")) 224 | train_labels = torch.load(os.path.join(args.load_features, "trainlabels.pth")) 225 | test_labels = torch.load(os.path.join(args.load_features, "testlabels.pth")) 226 | else: 227 | # need to extract features ! 228 | train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args) 229 | 230 | if utils.get_rank() == 0: 231 | if args.use_cuda: 232 | train_features = train_features.cuda() 233 | test_features = test_features.cuda() 234 | train_labels = train_labels.cuda() 235 | test_labels = test_labels.cuda() 236 | 237 | print("Features are ready!\nStart the k-NN classification.") 238 | for k in args.nb_knn: 239 | top1, top5 = knn_classifier(train_features, train_labels, 240 | test_features, test_labels, k, args.temperature) 241 | print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}") 242 | dist.barrier() 243 | -------------------------------------------------------------------------------- /eval_linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import argparse 16 | import json 17 | from pathlib import Path 18 | 19 | import torch 20 | from torch import nn 21 | import torch.distributed as dist 22 | import torch.backends.cudnn as cudnn 23 | from torchvision import datasets 24 | from torchvision import transforms as pth_transforms 25 | from torchvision import models as torchvision_models 26 | 27 | import utils 28 | import vision_transformer as vits 29 | 30 | 31 | def eval_linear(args): 32 | utils.init_distributed_mode(args) 33 | print("git:\n {}\n".format(utils.get_sha())) 34 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 35 | cudnn.benchmark = True 36 | 37 | # ============ building network ... ============ 38 | # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base) 39 | if args.arch in vits.__dict__.keys(): 40 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 41 | embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens)) 42 | # if the network is a XCiT 43 | elif "xcit" in args.arch: 44 | model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0) 45 | embed_dim = model.embed_dim 46 | # otherwise, we check if the architecture is in torchvision models 47 | elif args.arch in torchvision_models.__dict__.keys(): 48 | model = torchvision_models.__dict__[args.arch]() 49 | embed_dim = model.fc.weight.shape[1] 50 | model.fc = nn.Identity() 51 | else: 52 | print(f"Unknow architecture: {args.arch}") 53 | sys.exit(1) 54 | model.cuda() 55 | model.eval() 56 | # load weights to evaluate 57 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) 58 | print(f"Model {args.arch} built.") 59 | 60 | linear_classifier = LinearClassifier(embed_dim, num_labels=args.num_labels) 61 | linear_classifier = linear_classifier.cuda() 62 | linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier, device_ids=[args.gpu]) 63 | 64 | # ============ preparing data ... ============ 65 | val_transform = pth_transforms.Compose([ 66 | pth_transforms.Resize(256, interpolation=3), 67 | pth_transforms.CenterCrop(224), 68 | pth_transforms.ToTensor(), 69 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 70 | ]) 71 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, "val"), transform=val_transform) 72 | val_loader = torch.utils.data.DataLoader( 73 | dataset_val, 74 | batch_size=args.batch_size_per_gpu, 75 | num_workers=args.num_workers, 76 | pin_memory=True, 77 | ) 78 | 79 | if args.evaluate: 80 | utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size) 81 | test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens) 82 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 83 | return 84 | 85 | train_transform = pth_transforms.Compose([ 86 | pth_transforms.RandomResizedCrop(224), 87 | pth_transforms.RandomHorizontalFlip(), 88 | pth_transforms.ToTensor(), 89 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 90 | ]) 91 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, "train"), transform=train_transform) 92 | sampler = torch.utils.data.distributed.DistributedSampler(dataset_train) 93 | train_loader = torch.utils.data.DataLoader( 94 | dataset_train, 95 | sampler=sampler, 96 | batch_size=args.batch_size_per_gpu, 97 | num_workers=args.num_workers, 98 | pin_memory=True, 99 | ) 100 | print(f"Data loaded with {len(dataset_train)} train and {len(dataset_val)} val imgs.") 101 | 102 | # set optimizer 103 | optimizer = torch.optim.SGD( 104 | linear_classifier.parameters(), 105 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule 106 | momentum=0.9, 107 | weight_decay=0, # we do not apply weight decay 108 | ) 109 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=0) 110 | 111 | # Optionally resume from a checkpoint 112 | to_restore = {"epoch": 0, "best_acc": 0.} 113 | utils.restart_from_checkpoint( 114 | os.path.join(args.output_dir, "checkpoint.pth.tar"), 115 | run_variables=to_restore, 116 | state_dict=linear_classifier, 117 | optimizer=optimizer, 118 | scheduler=scheduler, 119 | ) 120 | start_epoch = to_restore["epoch"] 121 | best_acc = to_restore["best_acc"] 122 | 123 | for epoch in range(start_epoch, args.epochs): 124 | train_loader.sampler.set_epoch(epoch) 125 | 126 | train_stats = train(model, linear_classifier, optimizer, train_loader, epoch, args.n_last_blocks, args.avgpool_patchtokens) 127 | scheduler.step() 128 | 129 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 130 | 'epoch': epoch} 131 | if epoch % args.val_freq == 0 or epoch == args.epochs - 1: 132 | test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens) 133 | print(f"Accuracy at epoch {epoch} of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 134 | best_acc = max(best_acc, test_stats["acc1"]) 135 | print(f'Max accuracy so far: {best_acc:.2f}%') 136 | log_stats = {**{k: v for k, v in log_stats.items()}, 137 | **{f'test_{k}': v for k, v in test_stats.items()}} 138 | if utils.is_main_process(): 139 | with (Path(args.output_dir) / "log.txt").open("a") as f: 140 | f.write(json.dumps(log_stats) + "\n") 141 | save_dict = { 142 | "epoch": epoch + 1, 143 | "state_dict": linear_classifier.state_dict(), 144 | "optimizer": optimizer.state_dict(), 145 | "scheduler": scheduler.state_dict(), 146 | "best_acc": best_acc, 147 | } 148 | torch.save(save_dict, os.path.join(args.output_dir, "checkpoint.pth.tar")) 149 | print("Training of the supervised linear classifier on frozen features completed.\n" 150 | "Top-1 test accuracy: {acc:.1f}".format(acc=best_acc)) 151 | 152 | 153 | def train(model, linear_classifier, optimizer, loader, epoch, n, avgpool): 154 | linear_classifier.train() 155 | metric_logger = utils.MetricLogger(delimiter=" ") 156 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 157 | header = 'Epoch: [{}]'.format(epoch) 158 | for (inp, target) in metric_logger.log_every(loader, 20, header): 159 | # move to gpu 160 | inp = inp.cuda(non_blocking=True) 161 | target = target.cuda(non_blocking=True) 162 | 163 | # forward 164 | with torch.no_grad(): 165 | if "vit" in args.arch: 166 | intermediate_output = model.get_intermediate_layers(inp, n) 167 | output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) 168 | if avgpool: 169 | output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1) 170 | output = output.reshape(output.shape[0], -1) 171 | else: 172 | output = model(inp) 173 | output = linear_classifier(output) 174 | 175 | # compute cross entropy loss 176 | loss = nn.CrossEntropyLoss()(output, target) 177 | 178 | # compute the gradients 179 | optimizer.zero_grad() 180 | loss.backward() 181 | 182 | # step 183 | optimizer.step() 184 | 185 | # log 186 | torch.cuda.synchronize() 187 | metric_logger.update(loss=loss.item()) 188 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 189 | # gather the stats from all processes 190 | metric_logger.synchronize_between_processes() 191 | print("Averaged stats:", metric_logger) 192 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 193 | 194 | 195 | @torch.no_grad() 196 | def validate_network(val_loader, model, linear_classifier, n, avgpool): 197 | linear_classifier.eval() 198 | metric_logger = utils.MetricLogger(delimiter=" ") 199 | header = 'Test:' 200 | for inp, target in metric_logger.log_every(val_loader, 20, header): 201 | # move to gpu 202 | inp = inp.cuda(non_blocking=True) 203 | target = target.cuda(non_blocking=True) 204 | 205 | # forward 206 | with torch.no_grad(): 207 | if "vit" in args.arch: 208 | intermediate_output = model.get_intermediate_layers(inp, n) 209 | output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) 210 | if avgpool: 211 | output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1) 212 | output = output.reshape(output.shape[0], -1) 213 | else: 214 | output = model(inp) 215 | output = linear_classifier(output) 216 | loss = nn.CrossEntropyLoss()(output, target) 217 | 218 | if linear_classifier.module.num_labels >= 5: 219 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 220 | else: 221 | acc1, = utils.accuracy(output, target, topk=(1,)) 222 | 223 | batch_size = inp.shape[0] 224 | metric_logger.update(loss=loss.item()) 225 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 226 | if linear_classifier.module.num_labels >= 5: 227 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 228 | if linear_classifier.module.num_labels >= 5: 229 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 230 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 231 | else: 232 | print('* Acc@1 {top1.global_avg:.3f} loss {losses.global_avg:.3f}' 233 | .format(top1=metric_logger.acc1, losses=metric_logger.loss)) 234 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 235 | 236 | 237 | class LinearClassifier(nn.Module): 238 | """Linear layer to train on top of frozen features""" 239 | def __init__(self, dim, num_labels=1000): 240 | super(LinearClassifier, self).__init__() 241 | self.num_labels = num_labels 242 | self.linear = nn.Linear(dim, num_labels) 243 | self.linear.weight.data.normal_(mean=0.0, std=0.01) 244 | self.linear.bias.data.zero_() 245 | 246 | def forward(self, x): 247 | # flatten 248 | x = x.view(x.size(0), -1) 249 | 250 | # linear layer 251 | return self.linear(x) 252 | 253 | 254 | if __name__ == '__main__': 255 | parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet') 256 | parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens 257 | for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""") 258 | parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag, 259 | help="""Whether ot not to concatenate the global average pooled features to the [CLS] token. 260 | We typically set this to False for ViT-Small and to True with ViT-Base.""") 261 | parser.add_argument('--arch', default='vit_small', type=str, help='Architecture') 262 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 263 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 264 | parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")') 265 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.') 266 | parser.add_argument("--lr", default=0.001, type=float, help="""Learning rate at the beginning of 267 | training (highest LR used during training). The learning rate is linearly scaled 268 | with the batch size, and specified here for a reference batch size of 256. 269 | We recommend tweaking the LR depending on the checkpoint evaluated.""") 270 | parser.add_argument('--batch_size_per_gpu', default=128, type=int, help='Per-GPU batch-size') 271 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 272 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 273 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 274 | parser.add_argument('--data_path', default='/path/to/imagenet/', type=str) 275 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 276 | parser.add_argument('--val_freq', default=1, type=int, help="Epoch frequency for validation.") 277 | parser.add_argument('--output_dir', default=".", help='Path to save logs and checkpoints') 278 | parser.add_argument('--num_labels', default=1000, type=int, help='Number of labels for linear classifier') 279 | parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') 280 | args = parser.parse_args() 281 | eval_linear(args) 282 | -------------------------------------------------------------------------------- /eval_video_segmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Some parts are taken from https://github.com/Liusifei/UVC 16 | """ 17 | import os 18 | import copy 19 | import glob 20 | import queue 21 | from urllib.request import urlopen 22 | import argparse 23 | import numpy as np 24 | from tqdm import tqdm 25 | 26 | import cv2 27 | import torch 28 | import torch.nn as nn 29 | from torch.nn import functional as F 30 | from PIL import Image 31 | from torchvision import transforms 32 | 33 | import utils 34 | import vision_transformer as vits 35 | 36 | 37 | @torch.no_grad() 38 | def eval_video_tracking_davis(args, model, frame_list, video_dir, first_seg, seg_ori, color_palette): 39 | """ 40 | Evaluate tracking on a video given first frame & segmentation 41 | """ 42 | video_folder = os.path.join(args.output_dir, video_dir.split('/')[-1]) 43 | os.makedirs(video_folder, exist_ok=True) 44 | 45 | # The queue stores the n preceeding frames 46 | que = queue.Queue(args.n_last_frames) 47 | 48 | # first frame 49 | frame1, ori_h, ori_w = read_frame(frame_list[0]) 50 | # extract first frame feature 51 | frame1_feat = extract_feature(model, frame1).T # dim x h*w 52 | 53 | # saving first segmentation 54 | out_path = os.path.join(video_folder, "00000.png") 55 | imwrite_indexed(out_path, seg_ori, color_palette) 56 | mask_neighborhood = None 57 | for cnt in tqdm(range(1, len(frame_list))): 58 | frame_tar = read_frame(frame_list[cnt])[0] 59 | 60 | # we use the first segmentation and the n previous ones 61 | used_frame_feats = [frame1_feat] + [pair[0] for pair in list(que.queue)] 62 | used_segs = [first_seg] + [pair[1] for pair in list(que.queue)] 63 | 64 | frame_tar_avg, feat_tar, mask_neighborhood = label_propagation(args, model, frame_tar, used_frame_feats, used_segs, mask_neighborhood) 65 | 66 | # pop out oldest frame if neccessary 67 | if que.qsize() == args.n_last_frames: 68 | que.get() 69 | # push current results into queue 70 | seg = copy.deepcopy(frame_tar_avg) 71 | que.put([feat_tar, seg]) 72 | 73 | # upsampling & argmax 74 | frame_tar_avg = F.interpolate(frame_tar_avg, scale_factor=args.patch_size, mode='bilinear', align_corners=False, recompute_scale_factor=False)[0] 75 | frame_tar_avg = norm_mask(frame_tar_avg) 76 | _, frame_tar_seg = torch.max(frame_tar_avg, dim=0) 77 | 78 | # saving to disk 79 | frame_tar_seg = np.array(frame_tar_seg.squeeze().cpu(), dtype=np.uint8) 80 | frame_tar_seg = np.array(Image.fromarray(frame_tar_seg).resize((ori_w, ori_h), 0)) 81 | frame_nm = frame_list[cnt].split('/')[-1].replace(".jpg", ".png") 82 | imwrite_indexed(os.path.join(video_folder, frame_nm), frame_tar_seg, color_palette) 83 | 84 | 85 | def restrict_neighborhood(h, w): 86 | # We restrict the set of source nodes considered to a spatial neighborhood of the query node (i.e. ``local attention'') 87 | mask = torch.zeros(h, w, h, w) 88 | for i in range(h): 89 | for j in range(w): 90 | for p in range(2 * args.size_mask_neighborhood + 1): 91 | for q in range(2 * args.size_mask_neighborhood + 1): 92 | if i - args.size_mask_neighborhood + p < 0 or i - args.size_mask_neighborhood + p >= h: 93 | continue 94 | if j - args.size_mask_neighborhood + q < 0 or j - args.size_mask_neighborhood + q >= w: 95 | continue 96 | mask[i, j, i - args.size_mask_neighborhood + p, j - args.size_mask_neighborhood + q] = 1 97 | 98 | mask = mask.reshape(h * w, h * w) 99 | return mask.cuda(non_blocking=True) 100 | 101 | 102 | def norm_mask(mask): 103 | c, h, w = mask.size() 104 | for cnt in range(c): 105 | mask_cnt = mask[cnt,:,:] 106 | if(mask_cnt.max() > 0): 107 | mask_cnt = (mask_cnt - mask_cnt.min()) 108 | mask_cnt = mask_cnt/mask_cnt.max() 109 | mask[cnt,:,:] = mask_cnt 110 | return mask 111 | 112 | 113 | def label_propagation(args, model, frame_tar, list_frame_feats, list_segs, mask_neighborhood=None): 114 | """ 115 | propagate segs of frames in list_frames to frame_tar 116 | """ 117 | ## we only need to extract feature of the target frame 118 | feat_tar, h, w = extract_feature(model, frame_tar, return_h_w=True) 119 | 120 | return_feat_tar = feat_tar.T # dim x h*w 121 | 122 | ncontext = len(list_frame_feats) 123 | feat_sources = torch.stack(list_frame_feats) # nmb_context x dim x h*w 124 | 125 | feat_tar = F.normalize(feat_tar, dim=1, p=2) 126 | feat_sources = F.normalize(feat_sources, dim=1, p=2) 127 | 128 | feat_tar = feat_tar.unsqueeze(0).repeat(ncontext, 1, 1) 129 | aff = torch.exp(torch.bmm(feat_tar, feat_sources) / 0.1) # nmb_context x h*w (tar: query) x h*w (source: keys) 130 | 131 | if args.size_mask_neighborhood > 0: 132 | if mask_neighborhood is None: 133 | mask_neighborhood = restrict_neighborhood(h, w) 134 | mask_neighborhood = mask_neighborhood.unsqueeze(0).repeat(ncontext, 1, 1) 135 | aff *= mask_neighborhood 136 | 137 | aff = aff.transpose(2, 1).reshape(-1, h * w) # nmb_context*h*w (source: keys) x h*w (tar: queries) 138 | tk_val, _ = torch.topk(aff, dim=0, k=args.topk) 139 | tk_val_min, _ = torch.min(tk_val, dim=0) 140 | aff[aff < tk_val_min] = 0 141 | 142 | aff = aff / torch.sum(aff, keepdim=True, axis=0) 143 | 144 | list_segs = [s.cuda() for s in list_segs] 145 | segs = torch.cat(list_segs) 146 | nmb_context, C, h, w = segs.shape 147 | segs = segs.reshape(nmb_context, C, -1).transpose(2, 1).reshape(-1, C).T # C x nmb_context*h*w 148 | seg_tar = torch.mm(segs, aff) 149 | seg_tar = seg_tar.reshape(1, C, h, w) 150 | return seg_tar, return_feat_tar, mask_neighborhood 151 | 152 | 153 | def extract_feature(model, frame, return_h_w=False): 154 | """Extract one frame feature everytime.""" 155 | out = model.get_intermediate_layers(frame.unsqueeze(0).cuda(), n=1)[0] 156 | out = out[:, 1:, :] # we discard the [CLS] token 157 | h, w = int(frame.shape[1] / model.patch_embed.patch_size), int(frame.shape[2] / model.patch_embed.patch_size) 158 | dim = out.shape[-1] 159 | out = out[0].reshape(h, w, dim) 160 | out = out.reshape(-1, dim) 161 | if return_h_w: 162 | return out, h, w 163 | return out 164 | 165 | 166 | def imwrite_indexed(filename, array, color_palette): 167 | """ Save indexed png for DAVIS.""" 168 | if np.atleast_3d(array).shape[2] != 1: 169 | raise Exception("Saving indexed PNGs requires 2D array.") 170 | 171 | im = Image.fromarray(array) 172 | im.putpalette(color_palette.ravel()) 173 | im.save(filename, format='PNG') 174 | 175 | 176 | def to_one_hot(y_tensor, n_dims=None): 177 | """ 178 | Take integer y (tensor or variable) with n dims & 179 | convert it to 1-hot representation with n+1 dims. 180 | """ 181 | if(n_dims is None): 182 | n_dims = int(y_tensor.max()+ 1) 183 | _,h,w = y_tensor.size() 184 | y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1) 185 | n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1 186 | y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1) 187 | y_one_hot = y_one_hot.view(h,w,n_dims) 188 | return y_one_hot.permute(2, 0, 1).unsqueeze(0) 189 | 190 | 191 | def read_frame_list(video_dir): 192 | frame_list = [img for img in glob.glob(os.path.join(video_dir,"*.jpg"))] 193 | frame_list = sorted(frame_list) 194 | return frame_list 195 | 196 | 197 | def read_frame(frame_dir, scale_size=[480]): 198 | """ 199 | read a single frame & preprocess 200 | """ 201 | img = cv2.imread(frame_dir) 202 | ori_h, ori_w, _ = img.shape 203 | if len(scale_size) == 1: 204 | if(ori_h > ori_w): 205 | tw = scale_size[0] 206 | th = (tw * ori_h) / ori_w 207 | th = int((th // 64) * 64) 208 | else: 209 | th = scale_size[0] 210 | tw = (th * ori_w) / ori_h 211 | tw = int((tw // 64) * 64) 212 | else: 213 | th, tw = scale_size 214 | img = cv2.resize(img, (tw, th)) 215 | img = img.astype(np.float32) 216 | img = img / 255.0 217 | img = img[:, :, ::-1] 218 | img = np.transpose(img.copy(), (2, 0, 1)) 219 | img = torch.from_numpy(img).float() 220 | img = color_normalize(img) 221 | return img, ori_h, ori_w 222 | 223 | 224 | def read_seg(seg_dir, factor, scale_size=[480]): 225 | seg = Image.open(seg_dir) 226 | _w, _h = seg.size # note PIL.Image.Image's size is (w, h) 227 | if len(scale_size) == 1: 228 | if(_w > _h): 229 | _th = scale_size[0] 230 | _tw = (_th * _w) / _h 231 | _tw = int((_tw // 64) * 64) 232 | else: 233 | _tw = scale_size[0] 234 | _th = (_tw * _h) / _w 235 | _th = int((_th // 64) * 64) 236 | else: 237 | _th = scale_size[1] 238 | _tw = scale_size[0] 239 | small_seg = np.array(seg.resize((_tw // factor, _th // factor), 0)) 240 | small_seg = torch.from_numpy(small_seg.copy()).contiguous().float().unsqueeze(0) 241 | return to_one_hot(small_seg), np.asarray(seg) 242 | 243 | 244 | def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]): 245 | for t, m, s in zip(x, mean, std): 246 | t.sub_(m) 247 | t.div_(s) 248 | return x 249 | 250 | 251 | if __name__ == '__main__': 252 | parser = argparse.ArgumentParser('Evaluation with video object segmentation on DAVIS 2017') 253 | parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.") 254 | parser.add_argument('--arch', default='vit_small', type=str, 255 | choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).') 256 | parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.') 257 | parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")') 258 | parser.add_argument('--output_dir', default=".", help='Path where to save segmentations') 259 | parser.add_argument('--data_path', default='/path/to/davis/', type=str) 260 | parser.add_argument("--n_last_frames", type=int, default=7, help="number of preceeding frames") 261 | parser.add_argument("--size_mask_neighborhood", default=12, type=int, 262 | help="We restrict the set of source nodes considered to a spatial neighborhood of the query node") 263 | parser.add_argument("--topk", type=int, default=5, help="accumulate label from top k neighbors") 264 | parser.add_argument("--bs", type=int, default=6, help="Batch size, try to reduce if OOM") 265 | args = parser.parse_args() 266 | 267 | print("git:\n {}\n".format(utils.get_sha())) 268 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 269 | 270 | # building network 271 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 272 | print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.") 273 | model.cuda() 274 | utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size) 275 | for param in model.parameters(): 276 | param.requires_grad = False 277 | model.eval() 278 | 279 | color_palette = [] 280 | for line in urlopen("https://raw.githubusercontent.com/Liusifei/UVC/master/libs/data/palette.txt"): 281 | color_palette.append([int(i) for i in line.decode("utf-8").split('\n')[0].split(" ")]) 282 | color_palette = np.asarray(color_palette, dtype=np.uint8).reshape(-1,3) 283 | 284 | video_list = open(os.path.join(args.data_path, "ImageSets/2017/val.txt")).readlines() 285 | for i, video_name in enumerate(video_list): 286 | video_name = video_name.strip() 287 | print(f'[{i}/{len(video_list)}] Begin to segmentate video {video_name}.') 288 | video_dir = os.path.join(args.data_path, "JPEGImages/480p/", video_name) 289 | frame_list = read_frame_list(video_dir) 290 | seg_path = frame_list[0].replace("JPEGImages", "Annotations").replace("jpg", "png") 291 | first_seg, seg_ori = read_seg(seg_path, args.patch_size) 292 | eval_video_tracking_davis(args, model, frame_list, video_dir, first_seg, seg_ori, color_palette) 293 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | from torchvision.models.resnet import resnet50 16 | 17 | import vision_transformer as vits 18 | 19 | dependencies = ["torch", "torchvision"] 20 | 21 | 22 | def dino_vits16(pretrained=True, **kwargs): 23 | """ 24 | ViT-Small/16x16 pre-trained with DINO. 25 | Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification. 26 | """ 27 | model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs) 28 | if pretrained: 29 | state_dict = torch.hub.load_state_dict_from_url( 30 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", 31 | map_location="cpu", 32 | ) 33 | model.load_state_dict(state_dict, strict=True) 34 | return model 35 | 36 | 37 | def dino_vits8(pretrained=True, **kwargs): 38 | """ 39 | ViT-Small/8x8 pre-trained with DINO. 40 | Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification. 41 | """ 42 | model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs) 43 | if pretrained: 44 | state_dict = torch.hub.load_state_dict_from_url( 45 | url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", 46 | map_location="cpu", 47 | ) 48 | model.load_state_dict(state_dict, strict=True) 49 | return model 50 | 51 | 52 | def dino_vitb16(pretrained=True, **kwargs): 53 | """ 54 | ViT-Base/16x16 pre-trained with DINO. 55 | Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification. 56 | """ 57 | model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs) 58 | if pretrained: 59 | state_dict = torch.hub.load_state_dict_from_url( 60 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 61 | map_location="cpu", 62 | ) 63 | model.load_state_dict(state_dict, strict=True) 64 | return model 65 | 66 | 67 | def dino_vitb8(pretrained=True, **kwargs): 68 | """ 69 | ViT-Base/8x8 pre-trained with DINO. 70 | Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification. 71 | """ 72 | model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs) 73 | if pretrained: 74 | state_dict = torch.hub.load_state_dict_from_url( 75 | url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", 76 | map_location="cpu", 77 | ) 78 | model.load_state_dict(state_dict, strict=True) 79 | return model 80 | 81 | 82 | def dino_resnet50(pretrained=True, **kwargs): 83 | """ 84 | ResNet-50 pre-trained with DINO. 85 | Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`). 86 | """ 87 | model = resnet50(pretrained=False, **kwargs) 88 | model.fc = torch.nn.Identity() 89 | if pretrained: 90 | state_dict = torch.hub.load_state_dict_from_url( 91 | url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth", 92 | map_location="cpu", 93 | ) 94 | model.load_state_dict(state_dict, strict=False) 95 | return model 96 | 97 | 98 | def dino_xcit_small_12_p16(pretrained=True, **kwargs): 99 | """ 100 | XCiT-Small-12/16 pre-trained with DINO. 101 | """ 102 | model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p16", num_classes=0, **kwargs) 103 | if pretrained: 104 | state_dict = torch.hub.load_state_dict_from_url( 105 | url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth", 106 | map_location="cpu", 107 | ) 108 | model.load_state_dict(state_dict, strict=True) 109 | return model 110 | 111 | 112 | def dino_xcit_small_12_p8(pretrained=True, **kwargs): 113 | """ 114 | XCiT-Small-12/8 pre-trained with DINO. 115 | """ 116 | model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p8", num_classes=0, **kwargs) 117 | if pretrained: 118 | state_dict = torch.hub.load_state_dict_from_url( 119 | url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth", 120 | map_location="cpu", 121 | ) 122 | model.load_state_dict(state_dict, strict=True) 123 | return model 124 | 125 | 126 | def dino_xcit_medium_24_p16(pretrained=True, **kwargs): 127 | """ 128 | XCiT-Medium-24/16 pre-trained with DINO. 129 | """ 130 | model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p16", num_classes=0, **kwargs) 131 | if pretrained: 132 | state_dict = torch.hub.load_state_dict_from_url( 133 | url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth", 134 | map_location="cpu", 135 | ) 136 | model.load_state_dict(state_dict, strict=True) 137 | return model 138 | 139 | 140 | def dino_xcit_medium_24_p8(pretrained=True, **kwargs): 141 | """ 142 | XCiT-Medium-24/8 pre-trained with DINO. 143 | """ 144 | model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p8", num_classes=0, **kwargs) 145 | if pretrained: 146 | state_dict = torch.hub.load_state_dict_from_url( 147 | url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth", 148 | map_location="cpu", 149 | ) 150 | model.load_state_dict(state_dict, strict=True) 151 | return model 152 | -------------------------------------------------------------------------------- /main_dino.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | import os 16 | import sys 17 | import datetime 18 | import time 19 | import math 20 | import json 21 | from pathlib import Path 22 | 23 | import numpy as np 24 | from PIL import Image 25 | import torch 26 | import torch.nn as nn 27 | import torch.distributed as dist 28 | import torch.backends.cudnn as cudnn 29 | import torch.nn.functional as F 30 | from torchvision import datasets, transforms 31 | from torchvision import models as torchvision_models 32 | 33 | import utils 34 | import vision_transformer as vits 35 | from vision_transformer import DINOHead 36 | 37 | torchvision_archs = sorted(name for name in torchvision_models.__dict__ 38 | if name.islower() and not name.startswith("__") 39 | and callable(torchvision_models.__dict__[name])) 40 | 41 | def get_args_parser(): 42 | parser = argparse.ArgumentParser('DINO', add_help=False) 43 | 44 | # Model parameters 45 | parser.add_argument('--arch', default='vit_small', type=str, 46 | choices=['vit_tiny', 'vit_small', 'vit_base', 'xcit', 'deit_tiny', 'deit_small'] \ 47 | + torchvision_archs + torch.hub.list("facebookresearch/xcit:main"), 48 | help="""Name of architecture to train. For quick experiments with ViTs, 49 | we recommend using vit_tiny or vit_small.""") 50 | parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels 51 | of input square patches - default 16 (for 16x16 patches). Using smaller 52 | values leads to better performance but requires more memory. Applies only 53 | for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling 54 | mixed precision training (--use_fp16 false) to avoid unstabilities.""") 55 | parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of 56 | the DINO head output. For complex and large datasets large values (like 65k) work well.""") 57 | parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag, 58 | help="""Whether or not to weight normalize the last layer of the DINO head. 59 | Not normalizing leads to better performance but can make the training unstable. 60 | In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""") 61 | parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA 62 | parameter for teacher update. The value is increased to 1 during training with cosine schedule. 63 | We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""") 64 | parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag, 65 | help="Whether to use batch normalizations in projection head (Default: False)") 66 | 67 | # Temperature teacher parameters 68 | parser.add_argument('--warmup_teacher_temp', default=0.04, type=float, 69 | help="""Initial value for the teacher temperature: 0.04 works well in most cases. 70 | Try decreasing it if the training loss does not decrease.""") 71 | parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup) 72 | of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend 73 | starting with the default value of 0.04 and increase this slightly if needed.""") 74 | parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int, 75 | help='Number of warmup epochs for the teacher temperature (Default: 30).') 76 | 77 | # Training/Optimization parameters 78 | parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not 79 | to use half precision for training. Improves training time and memory requirements, 80 | but can provoke instability and slight decay of performance. We recommend disabling 81 | mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""") 82 | parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the 83 | weight decay. With ViT, a smaller value at the beginning of training works well.""") 84 | parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the 85 | weight decay. We use a cosine schedule for WD and using a larger decay by 86 | the end of training improves performance for ViTs.""") 87 | parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter 88 | gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can 89 | help optimization for larger ViT architectures. 0 for disabling.""") 90 | parser.add_argument('--batch_size_per_gpu', default=64, type=int, 91 | help='Per-GPU batch-size : number of distinct images loaded on one GPU.') 92 | parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.') 93 | parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs 94 | during which we keep the output layer fixed. Typically doing so during 95 | the first epoch helps training. Try increasing this value if the loss does not decrease.""") 96 | parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of 97 | linear warmup (highest LR used during training). The learning rate is linearly scaled 98 | with the batch size, and specified here for a reference batch size of 256.""") 99 | parser.add_argument("--warmup_epochs", default=10, type=int, 100 | help="Number of epochs for the linear learning-rate warm up.") 101 | parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the 102 | end of optimization. We use a cosine LR schedule with linear warmup.""") 103 | parser.add_argument('--optimizer', default='adamw', type=str, 104 | choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""") 105 | parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate") 106 | 107 | # Multi-crop parameters 108 | parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.), 109 | help="""Scale range of the cropped image before resizing, relatively to the origin image. 110 | Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we 111 | recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""") 112 | parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small 113 | local views to generate. Set this parameter to 0 to disable multi-crop training. 114 | When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """) 115 | parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4), 116 | help="""Scale range of the cropped image before resizing, relatively to the origin image. 117 | Used for small local view cropping of multi-crop.""") 118 | 119 | # Misc 120 | parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str, 121 | help='Please specify path to the ImageNet training data.') 122 | parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.') 123 | parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.') 124 | parser.add_argument('--seed', default=0, type=int, help='Random seed.') 125 | parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.') 126 | parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up 127 | distributed training; see https://pytorch.org/docs/stable/distributed.html""") 128 | parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.") 129 | return parser 130 | 131 | 132 | def train_dino(args): 133 | utils.init_distributed_mode(args) 134 | utils.fix_random_seeds(args.seed) 135 | print("git:\n {}\n".format(utils.get_sha())) 136 | print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 137 | cudnn.benchmark = True 138 | 139 | # ============ preparing data ... ============ 140 | transform = DataAugmentationDINO( 141 | args.global_crops_scale, 142 | args.local_crops_scale, 143 | args.local_crops_number, 144 | ) 145 | dataset = datasets.ImageFolder(args.data_path, transform=transform) 146 | sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True) 147 | data_loader = torch.utils.data.DataLoader( 148 | dataset, 149 | sampler=sampler, 150 | batch_size=args.batch_size_per_gpu, 151 | num_workers=args.num_workers, 152 | pin_memory=True, 153 | drop_last=True, 154 | ) 155 | print(f"Data loaded: there are {len(dataset)} images.") 156 | 157 | # ============ building student and teacher networks ... ============ 158 | # we changed the name DeiT-S for ViT-S to avoid confusions 159 | args.arch = args.arch.replace("deit", "vit") 160 | # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base) 161 | if args.arch in vits.__dict__.keys(): 162 | student = vits.__dict__[args.arch]( 163 | patch_size=args.patch_size, 164 | drop_path_rate=args.drop_path_rate, # stochastic depth 165 | ) 166 | teacher = vits.__dict__[args.arch](patch_size=args.patch_size) 167 | embed_dim = student.embed_dim 168 | # if the network is a XCiT 169 | elif args.arch in torch.hub.list("facebookresearch/xcit:main"): 170 | student = torch.hub.load('facebookresearch/xcit:main', args.arch, 171 | pretrained=False, drop_path_rate=args.drop_path_rate) 172 | teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False) 173 | embed_dim = student.embed_dim 174 | # otherwise, we check if the architecture is in torchvision models 175 | elif args.arch in torchvision_models.__dict__.keys(): 176 | student = torchvision_models.__dict__[args.arch]() 177 | teacher = torchvision_models.__dict__[args.arch]() 178 | embed_dim = student.fc.weight.shape[1] 179 | else: 180 | print(f"Unknow architecture: {args.arch}") 181 | 182 | # multi-crop wrapper handles forward with inputs of different resolutions 183 | student = utils.MultiCropWrapper(student, DINOHead( 184 | embed_dim, 185 | args.out_dim, 186 | use_bn=args.use_bn_in_head, 187 | norm_last_layer=args.norm_last_layer, 188 | )) 189 | teacher = utils.MultiCropWrapper( 190 | teacher, 191 | DINOHead(embed_dim, args.out_dim, args.use_bn_in_head), 192 | ) 193 | # move networks to gpu 194 | student, teacher = student.cuda(), teacher.cuda() 195 | # synchronize batch norms (if any) 196 | if utils.has_batchnorms(student): 197 | student = nn.SyncBatchNorm.convert_sync_batchnorm(student) 198 | teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher) 199 | 200 | # we need DDP wrapper to have synchro batch norms working... 201 | teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu]) 202 | teacher_without_ddp = teacher.module 203 | else: 204 | # teacher_without_ddp and teacher are the same thing 205 | teacher_without_ddp = teacher 206 | student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu]) 207 | # teacher and student start with the same weights 208 | teacher_without_ddp.load_state_dict(student.module.state_dict()) 209 | # there is no backpropagation through the teacher, so no need for gradients 210 | for p in teacher.parameters(): 211 | p.requires_grad = False 212 | print(f"Student and Teacher are built: they are both {args.arch} network.") 213 | 214 | # ============ preparing loss ... ============ 215 | dino_loss = DINOLoss( 216 | args.out_dim, 217 | args.local_crops_number + 2, # total number of crops = 2 global crops + local_crops_number 218 | args.warmup_teacher_temp, 219 | args.teacher_temp, 220 | args.warmup_teacher_temp_epochs, 221 | args.epochs, 222 | ).cuda() 223 | 224 | # ============ preparing optimizer ... ============ 225 | params_groups = utils.get_params_groups(student) 226 | if args.optimizer == "adamw": 227 | optimizer = torch.optim.AdamW(params_groups) # to use with ViTs 228 | elif args.optimizer == "sgd": 229 | optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9) # lr is set by scheduler 230 | elif args.optimizer == "lars": 231 | optimizer = utils.LARS(params_groups) # to use with convnet and large batches 232 | # for mixed precision training 233 | fp16_scaler = None 234 | if args.use_fp16: 235 | fp16_scaler = torch.cuda.amp.GradScaler() 236 | 237 | # ============ init schedulers ... ============ 238 | lr_schedule = utils.cosine_scheduler( 239 | args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule 240 | args.min_lr, 241 | args.epochs, len(data_loader), 242 | warmup_epochs=args.warmup_epochs, 243 | ) 244 | wd_schedule = utils.cosine_scheduler( 245 | args.weight_decay, 246 | args.weight_decay_end, 247 | args.epochs, len(data_loader), 248 | ) 249 | # momentum parameter is increased to 1. during training with a cosine schedule 250 | momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1, 251 | args.epochs, len(data_loader)) 252 | print(f"Loss, optimizer and schedulers ready.") 253 | 254 | # ============ optionally resume training ... ============ 255 | to_restore = {"epoch": 0} 256 | utils.restart_from_checkpoint( 257 | os.path.join(args.output_dir, "checkpoint.pth"), 258 | run_variables=to_restore, 259 | student=student, 260 | teacher=teacher, 261 | optimizer=optimizer, 262 | fp16_scaler=fp16_scaler, 263 | dino_loss=dino_loss, 264 | ) 265 | start_epoch = to_restore["epoch"] 266 | 267 | start_time = time.time() 268 | print("Starting DINO training !") 269 | for epoch in range(start_epoch, args.epochs): 270 | data_loader.sampler.set_epoch(epoch) 271 | 272 | # ============ training one epoch of DINO ... ============ 273 | train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, 274 | data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule, 275 | epoch, fp16_scaler, args) 276 | 277 | # ============ writing logs ... ============ 278 | save_dict = { 279 | 'student': student.state_dict(), 280 | 'teacher': teacher.state_dict(), 281 | 'optimizer': optimizer.state_dict(), 282 | 'epoch': epoch + 1, 283 | 'args': args, 284 | 'dino_loss': dino_loss.state_dict(), 285 | } 286 | if fp16_scaler is not None: 287 | save_dict['fp16_scaler'] = fp16_scaler.state_dict() 288 | utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth')) 289 | if args.saveckp_freq and epoch % args.saveckp_freq == 0: 290 | utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth')) 291 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 292 | 'epoch': epoch} 293 | if utils.is_main_process(): 294 | with (Path(args.output_dir) / "log.txt").open("a") as f: 295 | f.write(json.dumps(log_stats) + "\n") 296 | total_time = time.time() - start_time 297 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 298 | print('Training time {}'.format(total_time_str)) 299 | 300 | 301 | def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader, 302 | optimizer, lr_schedule, wd_schedule, momentum_schedule,epoch, 303 | fp16_scaler, args): 304 | metric_logger = utils.MetricLogger(delimiter=" ") 305 | header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) 306 | for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)): 307 | # update weight decay and learning rate according to their schedule 308 | it = len(data_loader) * epoch + it # global training iteration 309 | for i, param_group in enumerate(optimizer.param_groups): 310 | param_group["lr"] = lr_schedule[it] 311 | if i == 0: # only the first group is regularized 312 | param_group["weight_decay"] = wd_schedule[it] 313 | 314 | # move images to gpu 315 | images = [im.cuda(non_blocking=True) for im in images] 316 | # teacher and student forward passes + compute dino loss 317 | with torch.cuda.amp.autocast(fp16_scaler is not None): 318 | teacher_output = teacher(images[:2]) # only the 2 global views pass through the teacher 319 | student_output = student(images) 320 | loss = dino_loss(student_output, teacher_output, epoch) 321 | 322 | if not math.isfinite(loss.item()): 323 | print("Loss is {}, stopping training".format(loss.item()), force=True) 324 | sys.exit(1) 325 | 326 | # student update 327 | optimizer.zero_grad() 328 | param_norms = None 329 | if fp16_scaler is None: 330 | loss.backward() 331 | if args.clip_grad: 332 | param_norms = utils.clip_gradients(student, args.clip_grad) 333 | utils.cancel_gradients_last_layer(epoch, student, 334 | args.freeze_last_layer) 335 | optimizer.step() 336 | else: 337 | fp16_scaler.scale(loss).backward() 338 | if args.clip_grad: 339 | fp16_scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 340 | param_norms = utils.clip_gradients(student, args.clip_grad) 341 | utils.cancel_gradients_last_layer(epoch, student, 342 | args.freeze_last_layer) 343 | fp16_scaler.step(optimizer) 344 | fp16_scaler.update() 345 | 346 | # EMA update for the teacher 347 | with torch.no_grad(): 348 | m = momentum_schedule[it] # momentum parameter 349 | for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()): 350 | param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) 351 | 352 | # logging 353 | torch.cuda.synchronize() 354 | metric_logger.update(loss=loss.item()) 355 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 356 | metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"]) 357 | # gather the stats from all processes 358 | metric_logger.synchronize_between_processes() 359 | print("Averaged stats:", metric_logger) 360 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 361 | 362 | 363 | class DINOLoss(nn.Module): 364 | def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, 365 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1, 366 | center_momentum=0.9): 367 | super().__init__() 368 | self.student_temp = student_temp 369 | self.center_momentum = center_momentum 370 | self.ncrops = ncrops 371 | self.register_buffer("center", torch.zeros(1, out_dim)) 372 | # we apply a warm up for the teacher temperature because 373 | # a too high temperature makes the training instable at the beginning 374 | self.teacher_temp_schedule = np.concatenate(( 375 | np.linspace(warmup_teacher_temp, 376 | teacher_temp, warmup_teacher_temp_epochs), 377 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp 378 | )) 379 | 380 | def forward(self, student_output, teacher_output, epoch): 381 | """ 382 | Cross-entropy between softmax outputs of the teacher and student networks. 383 | """ 384 | student_out = student_output / self.student_temp 385 | student_out = student_out.chunk(self.ncrops) 386 | 387 | # teacher centering and sharpening 388 | temp = self.teacher_temp_schedule[epoch] 389 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) 390 | teacher_out = teacher_out.detach().chunk(2) 391 | 392 | total_loss = 0 393 | n_loss_terms = 0 394 | for iq, q in enumerate(teacher_out): 395 | for v in range(len(student_out)): 396 | if v == iq: 397 | # we skip cases where student and teacher operate on the same view 398 | continue 399 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 400 | total_loss += loss.mean() 401 | n_loss_terms += 1 402 | total_loss /= n_loss_terms 403 | self.update_center(teacher_output) 404 | return total_loss 405 | 406 | @torch.no_grad() 407 | def update_center(self, teacher_output): 408 | """ 409 | Update center used for teacher output. 410 | """ 411 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 412 | dist.all_reduce(batch_center) 413 | batch_center = batch_center / (len(teacher_output) * dist.get_world_size()) 414 | 415 | # ema update 416 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 417 | 418 | 419 | class DataAugmentationDINO(object): 420 | def __init__(self, global_crops_scale, local_crops_scale, local_crops_number): 421 | flip_and_color_jitter = transforms.Compose([ 422 | transforms.RandomHorizontalFlip(p=0.5), 423 | transforms.RandomApply( 424 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 425 | p=0.8 426 | ), 427 | transforms.RandomGrayscale(p=0.2), 428 | ]) 429 | normalize = transforms.Compose([ 430 | transforms.ToTensor(), 431 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 432 | ]) 433 | 434 | # first global crop 435 | self.global_transfo1 = transforms.Compose([ 436 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC), 437 | flip_and_color_jitter, 438 | utils.GaussianBlur(1.0), 439 | normalize, 440 | ]) 441 | # second global crop 442 | self.global_transfo2 = transforms.Compose([ 443 | transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC), 444 | flip_and_color_jitter, 445 | utils.GaussianBlur(0.1), 446 | utils.Solarization(0.2), 447 | normalize, 448 | ]) 449 | # transformation for the local small crops 450 | self.local_crops_number = local_crops_number 451 | self.local_transfo = transforms.Compose([ 452 | transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC), 453 | flip_and_color_jitter, 454 | utils.GaussianBlur(p=0.5), 455 | normalize, 456 | ]) 457 | 458 | def __call__(self, image): 459 | crops = [] 460 | crops.append(self.global_transfo1(image)) 461 | crops.append(self.global_transfo2(image)) 462 | for _ in range(self.local_crops_number): 463 | crops.append(self.local_transfo(image)) 464 | return crops 465 | 466 | 467 | if __name__ == '__main__': 468 | parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()]) 469 | args = parser.parse_args() 470 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 471 | train_dino(args) 472 | -------------------------------------------------------------------------------- /run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | A script to run multinode training with submitit. 16 | Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py 17 | """ 18 | import argparse 19 | import os 20 | import uuid 21 | from pathlib import Path 22 | 23 | import main_dino 24 | import submitit 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()]) 29 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 30 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 31 | parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") 32 | 33 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 34 | parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this") 35 | parser.add_argument('--comment', default="", type=str, 36 | help='Comment to pass to scheduler, e.g. priority message') 37 | return parser.parse_args() 38 | 39 | 40 | def get_shared_folder() -> Path: 41 | user = os.getenv("USER") 42 | if Path("/checkpoint/").is_dir(): 43 | p = Path(f"/checkpoint/{user}/experiments") 44 | p.mkdir(exist_ok=True) 45 | return p 46 | raise RuntimeError("No shared folder available") 47 | 48 | 49 | def get_init_file(): 50 | # Init file must not exist, but it's parent dir must exist. 51 | os.makedirs(str(get_shared_folder()), exist_ok=True) 52 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 53 | if init_file.exists(): 54 | os.remove(str(init_file)) 55 | return init_file 56 | 57 | 58 | class Trainer(object): 59 | def __init__(self, args): 60 | self.args = args 61 | 62 | def __call__(self): 63 | import main_dino 64 | 65 | self._setup_gpu_args() 66 | main_dino.train_dino(self.args) 67 | 68 | def checkpoint(self): 69 | import os 70 | import submitit 71 | 72 | self.args.dist_url = get_init_file().as_uri() 73 | print("Requeuing ", self.args) 74 | empty_trainer = type(self)(self.args) 75 | return submitit.helpers.DelayedSubmission(empty_trainer) 76 | 77 | def _setup_gpu_args(self): 78 | import submitit 79 | from pathlib import Path 80 | 81 | job_env = submitit.JobEnvironment() 82 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 83 | self.args.gpu = job_env.local_rank 84 | self.args.rank = job_env.global_rank 85 | self.args.world_size = job_env.num_tasks 86 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | if args.output_dir == "": 92 | args.output_dir = get_shared_folder() / "%j" 93 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 94 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 95 | 96 | num_gpus_per_node = args.ngpus 97 | nodes = args.nodes 98 | timeout_min = args.timeout 99 | 100 | partition = args.partition 101 | kwargs = {} 102 | if args.use_volta32: 103 | kwargs['slurm_constraint'] = 'volta32gb' 104 | if args.comment: 105 | kwargs['slurm_comment'] = args.comment 106 | 107 | executor.update_parameters( 108 | mem_gb=40 * num_gpus_per_node, 109 | gpus_per_node=num_gpus_per_node, 110 | tasks_per_node=num_gpus_per_node, # one task per GPU 111 | cpus_per_task=10, 112 | nodes=nodes, 113 | timeout_min=timeout_min, # max is 60 * 72 114 | # Below are cluster dependent parameters 115 | slurm_partition=partition, 116 | slurm_signal_delay_s=120, 117 | **kwargs 118 | ) 119 | 120 | executor.update_parameters(name="dino") 121 | 122 | args.dist_url = get_init_file().as_uri() 123 | 124 | trainer = Trainer(args) 125 | job = executor.submit(trainer) 126 | 127 | print(f"Submitted job_id: {job.job_id}") 128 | print(f"Logs and checkpoints will be saved at: {args.output_dir}") 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Misc functions. 16 | 17 | Mostly copy-paste from torchvision references or other public repos like DETR: 18 | https://github.com/facebookresearch/detr/blob/master/util/misc.py 19 | """ 20 | import os 21 | import sys 22 | import time 23 | import math 24 | import random 25 | import datetime 26 | import subprocess 27 | from collections import defaultdict, deque 28 | 29 | import numpy as np 30 | import torch 31 | from torch import nn 32 | import torch.distributed as dist 33 | from PIL import ImageFilter, ImageOps 34 | 35 | 36 | class GaussianBlur(object): 37 | """ 38 | Apply Gaussian Blur to the PIL image. 39 | """ 40 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 41 | self.prob = p 42 | self.radius_min = radius_min 43 | self.radius_max = radius_max 44 | 45 | def __call__(self, img): 46 | do_it = random.random() <= self.prob 47 | if not do_it: 48 | return img 49 | 50 | return img.filter( 51 | ImageFilter.GaussianBlur( 52 | radius=random.uniform(self.radius_min, self.radius_max) 53 | ) 54 | ) 55 | 56 | 57 | class Solarization(object): 58 | """ 59 | Apply Solarization to the PIL image. 60 | """ 61 | def __init__(self, p): 62 | self.p = p 63 | 64 | def __call__(self, img): 65 | if random.random() < self.p: 66 | return ImageOps.solarize(img) 67 | else: 68 | return img 69 | 70 | 71 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): 72 | if os.path.isfile(pretrained_weights): 73 | state_dict = torch.load(pretrained_weights, map_location="cpu") 74 | if checkpoint_key is not None and checkpoint_key in state_dict: 75 | print(f"Take key {checkpoint_key} in provided checkpoint dict") 76 | state_dict = state_dict[checkpoint_key] 77 | # remove `module.` prefix 78 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 79 | # remove `backbone.` prefix induced by multicrop wrapper 80 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 81 | msg = model.load_state_dict(state_dict, strict=False) 82 | print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) 83 | else: 84 | print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") 85 | url = None 86 | if model_name == "vit_small" and patch_size == 16: 87 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 88 | elif model_name == "vit_small" and patch_size == 8: 89 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" 90 | elif model_name == "vit_base" and patch_size == 16: 91 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 92 | elif model_name == "vit_base" and patch_size == 8: 93 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 94 | elif model_name == "xcit_small_12_p16": 95 | url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth" 96 | elif model_name == "xcit_small_12_p8": 97 | url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth" 98 | elif model_name == "xcit_medium_24_p16": 99 | url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth" 100 | elif model_name == "xcit_medium_24_p8": 101 | url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth" 102 | elif model_name == "resnet50": 103 | url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth" 104 | if url is not None: 105 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") 106 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 107 | model.load_state_dict(state_dict, strict=True) 108 | else: 109 | print("There is no reference weights available for this model => We use random weights.") 110 | 111 | 112 | def load_pretrained_linear_weights(linear_classifier, model_name, patch_size): 113 | url = None 114 | if model_name == "vit_small" and patch_size == 16: 115 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth" 116 | elif model_name == "vit_small" and patch_size == 8: 117 | url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth" 118 | elif model_name == "vit_base" and patch_size == 16: 119 | url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth" 120 | elif model_name == "vit_base" and patch_size == 8: 121 | url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth" 122 | elif model_name == "resnet50": 123 | url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth" 124 | if url is not None: 125 | print("We load the reference pretrained linear weights.") 126 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"] 127 | linear_classifier.load_state_dict(state_dict, strict=True) 128 | else: 129 | print("We use random linear weights.") 130 | 131 | 132 | def clip_gradients(model, clip): 133 | norms = [] 134 | for name, p in model.named_parameters(): 135 | if p.grad is not None: 136 | param_norm = p.grad.data.norm(2) 137 | norms.append(param_norm.item()) 138 | clip_coef = clip / (param_norm + 1e-6) 139 | if clip_coef < 1: 140 | p.grad.data.mul_(clip_coef) 141 | return norms 142 | 143 | 144 | def cancel_gradients_last_layer(epoch, model, freeze_last_layer): 145 | if epoch >= freeze_last_layer: 146 | return 147 | for n, p in model.named_parameters(): 148 | if "last_layer" in n: 149 | p.grad = None 150 | 151 | 152 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): 153 | """ 154 | Re-start from checkpoint 155 | """ 156 | if not os.path.isfile(ckp_path): 157 | return 158 | print("Found checkpoint at {}".format(ckp_path)) 159 | 160 | # open checkpoint file 161 | checkpoint = torch.load(ckp_path, map_location="cpu") 162 | 163 | # key is what to look for in the checkpoint file 164 | # value is the object to load 165 | # example: {'state_dict': model} 166 | for key, value in kwargs.items(): 167 | if key in checkpoint and value is not None: 168 | try: 169 | msg = value.load_state_dict(checkpoint[key], strict=False) 170 | print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) 171 | except TypeError: 172 | try: 173 | msg = value.load_state_dict(checkpoint[key]) 174 | print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path)) 175 | except ValueError: 176 | print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path)) 177 | else: 178 | print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path)) 179 | 180 | # re load variable important for the run 181 | if run_variables is not None: 182 | for var_name in run_variables: 183 | if var_name in checkpoint: 184 | run_variables[var_name] = checkpoint[var_name] 185 | 186 | 187 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 188 | warmup_schedule = np.array([]) 189 | warmup_iters = warmup_epochs * niter_per_ep 190 | if warmup_epochs > 0: 191 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 192 | 193 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 194 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 195 | 196 | schedule = np.concatenate((warmup_schedule, schedule)) 197 | assert len(schedule) == epochs * niter_per_ep 198 | return schedule 199 | 200 | 201 | def bool_flag(s): 202 | """ 203 | Parse boolean arguments from the command line. 204 | """ 205 | FALSY_STRINGS = {"off", "false", "0"} 206 | TRUTHY_STRINGS = {"on", "true", "1"} 207 | if s.lower() in FALSY_STRINGS: 208 | return False 209 | elif s.lower() in TRUTHY_STRINGS: 210 | return True 211 | else: 212 | raise argparse.ArgumentTypeError("invalid value for a boolean flag") 213 | 214 | 215 | def fix_random_seeds(seed=31): 216 | """ 217 | Fix random seeds. 218 | """ 219 | torch.manual_seed(seed) 220 | torch.cuda.manual_seed_all(seed) 221 | np.random.seed(seed) 222 | 223 | 224 | class SmoothedValue(object): 225 | """Track a series of values and provide access to smoothed values over a 226 | window or the global series average. 227 | """ 228 | 229 | def __init__(self, window_size=20, fmt=None): 230 | if fmt is None: 231 | fmt = "{median:.6f} ({global_avg:.6f})" 232 | self.deque = deque(maxlen=window_size) 233 | self.total = 0.0 234 | self.count = 0 235 | self.fmt = fmt 236 | 237 | def update(self, value, n=1): 238 | self.deque.append(value) 239 | self.count += n 240 | self.total += value * n 241 | 242 | def synchronize_between_processes(self): 243 | """ 244 | Warning: does not synchronize the deque! 245 | """ 246 | if not is_dist_avail_and_initialized(): 247 | return 248 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 249 | dist.barrier() 250 | dist.all_reduce(t) 251 | t = t.tolist() 252 | self.count = int(t[0]) 253 | self.total = t[1] 254 | 255 | @property 256 | def median(self): 257 | d = torch.tensor(list(self.deque)) 258 | return d.median().item() 259 | 260 | @property 261 | def avg(self): 262 | d = torch.tensor(list(self.deque), dtype=torch.float32) 263 | return d.mean().item() 264 | 265 | @property 266 | def global_avg(self): 267 | return self.total / self.count 268 | 269 | @property 270 | def max(self): 271 | return max(self.deque) 272 | 273 | @property 274 | def value(self): 275 | return self.deque[-1] 276 | 277 | def __str__(self): 278 | return self.fmt.format( 279 | median=self.median, 280 | avg=self.avg, 281 | global_avg=self.global_avg, 282 | max=self.max, 283 | value=self.value) 284 | 285 | 286 | def reduce_dict(input_dict, average=True): 287 | """ 288 | Args: 289 | input_dict (dict): all the values will be reduced 290 | average (bool): whether to do average or sum 291 | Reduce the values in the dictionary from all processes so that all processes 292 | have the averaged results. Returns a dict with the same fields as 293 | input_dict, after reduction. 294 | """ 295 | world_size = get_world_size() 296 | if world_size < 2: 297 | return input_dict 298 | with torch.no_grad(): 299 | names = [] 300 | values = [] 301 | # sort the keys so that they are consistent across processes 302 | for k in sorted(input_dict.keys()): 303 | names.append(k) 304 | values.append(input_dict[k]) 305 | values = torch.stack(values, dim=0) 306 | dist.all_reduce(values) 307 | if average: 308 | values /= world_size 309 | reduced_dict = {k: v for k, v in zip(names, values)} 310 | return reduced_dict 311 | 312 | 313 | class MetricLogger(object): 314 | def __init__(self, delimiter="\t"): 315 | self.meters = defaultdict(SmoothedValue) 316 | self.delimiter = delimiter 317 | 318 | def update(self, **kwargs): 319 | for k, v in kwargs.items(): 320 | if isinstance(v, torch.Tensor): 321 | v = v.item() 322 | assert isinstance(v, (float, int)) 323 | self.meters[k].update(v) 324 | 325 | def __getattr__(self, attr): 326 | if attr in self.meters: 327 | return self.meters[attr] 328 | if attr in self.__dict__: 329 | return self.__dict__[attr] 330 | raise AttributeError("'{}' object has no attribute '{}'".format( 331 | type(self).__name__, attr)) 332 | 333 | def __str__(self): 334 | loss_str = [] 335 | for name, meter in self.meters.items(): 336 | loss_str.append( 337 | "{}: {}".format(name, str(meter)) 338 | ) 339 | return self.delimiter.join(loss_str) 340 | 341 | def synchronize_between_processes(self): 342 | for meter in self.meters.values(): 343 | meter.synchronize_between_processes() 344 | 345 | def add_meter(self, name, meter): 346 | self.meters[name] = meter 347 | 348 | def log_every(self, iterable, print_freq, header=None): 349 | i = 0 350 | if not header: 351 | header = '' 352 | start_time = time.time() 353 | end = time.time() 354 | iter_time = SmoothedValue(fmt='{avg:.6f}') 355 | data_time = SmoothedValue(fmt='{avg:.6f}') 356 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 357 | if torch.cuda.is_available(): 358 | log_msg = self.delimiter.join([ 359 | header, 360 | '[{0' + space_fmt + '}/{1}]', 361 | 'eta: {eta}', 362 | '{meters}', 363 | 'time: {time}', 364 | 'data: {data}', 365 | 'max mem: {memory:.0f}' 366 | ]) 367 | else: 368 | log_msg = self.delimiter.join([ 369 | header, 370 | '[{0' + space_fmt + '}/{1}]', 371 | 'eta: {eta}', 372 | '{meters}', 373 | 'time: {time}', 374 | 'data: {data}' 375 | ]) 376 | MB = 1024.0 * 1024.0 377 | for obj in iterable: 378 | data_time.update(time.time() - end) 379 | yield obj 380 | iter_time.update(time.time() - end) 381 | if i % print_freq == 0 or i == len(iterable) - 1: 382 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 383 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 384 | if torch.cuda.is_available(): 385 | print(log_msg.format( 386 | i, len(iterable), eta=eta_string, 387 | meters=str(self), 388 | time=str(iter_time), data=str(data_time), 389 | memory=torch.cuda.max_memory_allocated() / MB)) 390 | else: 391 | print(log_msg.format( 392 | i, len(iterable), eta=eta_string, 393 | meters=str(self), 394 | time=str(iter_time), data=str(data_time))) 395 | i += 1 396 | end = time.time() 397 | total_time = time.time() - start_time 398 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 399 | print('{} Total time: {} ({:.6f} s / it)'.format( 400 | header, total_time_str, total_time / len(iterable))) 401 | 402 | 403 | def get_sha(): 404 | cwd = os.path.dirname(os.path.abspath(__file__)) 405 | 406 | def _run(command): 407 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 408 | sha = 'N/A' 409 | diff = "clean" 410 | branch = 'N/A' 411 | try: 412 | sha = _run(['git', 'rev-parse', 'HEAD']) 413 | subprocess.check_output(['git', 'diff'], cwd=cwd) 414 | diff = _run(['git', 'diff-index', 'HEAD']) 415 | diff = "has uncommited changes" if diff else "clean" 416 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 417 | except Exception: 418 | pass 419 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 420 | return message 421 | 422 | 423 | def is_dist_avail_and_initialized(): 424 | if not dist.is_available(): 425 | return False 426 | if not dist.is_initialized(): 427 | return False 428 | return True 429 | 430 | 431 | def get_world_size(): 432 | if not is_dist_avail_and_initialized(): 433 | return 1 434 | return dist.get_world_size() 435 | 436 | 437 | def get_rank(): 438 | if not is_dist_avail_and_initialized(): 439 | return 0 440 | return dist.get_rank() 441 | 442 | 443 | def is_main_process(): 444 | return get_rank() == 0 445 | 446 | 447 | def save_on_master(*args, **kwargs): 448 | if is_main_process(): 449 | torch.save(*args, **kwargs) 450 | 451 | 452 | def setup_for_distributed(is_master): 453 | """ 454 | This function disables printing when not in master process 455 | """ 456 | import builtins as __builtin__ 457 | builtin_print = __builtin__.print 458 | 459 | def print(*args, **kwargs): 460 | force = kwargs.pop('force', False) 461 | if is_master or force: 462 | builtin_print(*args, **kwargs) 463 | 464 | __builtin__.print = print 465 | 466 | 467 | def init_distributed_mode(args): 468 | # launched with torch.distributed.launch 469 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 470 | args.rank = int(os.environ["RANK"]) 471 | args.world_size = int(os.environ['WORLD_SIZE']) 472 | args.gpu = int(os.environ['LOCAL_RANK']) 473 | # launched with submitit on a slurm cluster 474 | elif 'SLURM_PROCID' in os.environ: 475 | args.rank = int(os.environ['SLURM_PROCID']) 476 | args.gpu = args.rank % torch.cuda.device_count() 477 | # launched naively with `python main_dino.py` 478 | # we manually add MASTER_ADDR and MASTER_PORT to env variables 479 | elif torch.cuda.is_available(): 480 | print('Will run the code on one GPU.') 481 | args.rank, args.gpu, args.world_size = 0, 0, 1 482 | os.environ['MASTER_ADDR'] = '127.0.0.1' 483 | os.environ['MASTER_PORT'] = '29500' 484 | else: 485 | print('Does not support training without GPU.') 486 | sys.exit(1) 487 | 488 | dist.init_process_group( 489 | backend="nccl", 490 | init_method=args.dist_url, 491 | world_size=args.world_size, 492 | rank=args.rank, 493 | ) 494 | 495 | torch.cuda.set_device(args.gpu) 496 | print('| distributed init (rank {}): {}'.format( 497 | args.rank, args.dist_url), flush=True) 498 | dist.barrier() 499 | setup_for_distributed(args.rank == 0) 500 | 501 | 502 | def accuracy(output, target, topk=(1,)): 503 | """Computes the accuracy over the k top predictions for the specified values of k""" 504 | maxk = max(topk) 505 | batch_size = target.size(0) 506 | _, pred = output.topk(maxk, 1, True, True) 507 | pred = pred.t() 508 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 509 | return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] 510 | 511 | 512 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 513 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 514 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 515 | def norm_cdf(x): 516 | # Computes standard normal cumulative distribution function 517 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 518 | 519 | if (mean < a - 2 * std) or (mean > b + 2 * std): 520 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 521 | "The distribution of values may be incorrect.", 522 | stacklevel=2) 523 | 524 | with torch.no_grad(): 525 | # Values are generated by using a truncated uniform distribution and 526 | # then using the inverse CDF for the normal distribution. 527 | # Get upper and lower cdf values 528 | l = norm_cdf((a - mean) / std) 529 | u = norm_cdf((b - mean) / std) 530 | 531 | # Uniformly fill tensor with values from [l, u], then translate to 532 | # [2l-1, 2u-1]. 533 | tensor.uniform_(2 * l - 1, 2 * u - 1) 534 | 535 | # Use inverse cdf transform for normal distribution to get truncated 536 | # standard normal 537 | tensor.erfinv_() 538 | 539 | # Transform to proper mean, std 540 | tensor.mul_(std * math.sqrt(2.)) 541 | tensor.add_(mean) 542 | 543 | # Clamp to ensure it's in the proper range 544 | tensor.clamp_(min=a, max=b) 545 | return tensor 546 | 547 | 548 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 549 | # type: (Tensor, float, float, float, float) -> Tensor 550 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 551 | 552 | 553 | class LARS(torch.optim.Optimizer): 554 | """ 555 | Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py 556 | """ 557 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, 558 | weight_decay_filter=None, lars_adaptation_filter=None): 559 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, 560 | eta=eta, weight_decay_filter=weight_decay_filter, 561 | lars_adaptation_filter=lars_adaptation_filter) 562 | super().__init__(params, defaults) 563 | 564 | @torch.no_grad() 565 | def step(self): 566 | for g in self.param_groups: 567 | for p in g['params']: 568 | dp = p.grad 569 | 570 | if dp is None: 571 | continue 572 | 573 | if p.ndim != 1: 574 | dp = dp.add(p, alpha=g['weight_decay']) 575 | 576 | if p.ndim != 1: 577 | param_norm = torch.norm(p) 578 | update_norm = torch.norm(dp) 579 | one = torch.ones_like(param_norm) 580 | q = torch.where(param_norm > 0., 581 | torch.where(update_norm > 0, 582 | (g['eta'] * param_norm / update_norm), one), one) 583 | dp = dp.mul(q) 584 | 585 | param_state = self.state[p] 586 | if 'mu' not in param_state: 587 | param_state['mu'] = torch.zeros_like(p) 588 | mu = param_state['mu'] 589 | mu.mul_(g['momentum']).add_(dp) 590 | 591 | p.add_(mu, alpha=-g['lr']) 592 | 593 | 594 | class MultiCropWrapper(nn.Module): 595 | """ 596 | Perform forward pass separately on each resolution input. 597 | The inputs corresponding to a single resolution are clubbed and single 598 | forward is run on the same resolution inputs. Hence we do several 599 | forward passes = number of different resolutions used. We then 600 | concatenate all the output features and run the head forward on these 601 | concatenated features. 602 | """ 603 | def __init__(self, backbone, head): 604 | super(MultiCropWrapper, self).__init__() 605 | # disable layers dedicated to ImageNet labels classification 606 | backbone.fc, backbone.head = nn.Identity(), nn.Identity() 607 | self.backbone = backbone 608 | self.head = head 609 | 610 | def forward(self, x): 611 | # convert to list 612 | if not isinstance(x, list): 613 | x = [x] 614 | idx_crops = torch.cumsum(torch.unique_consecutive( 615 | torch.tensor([inp.shape[-1] for inp in x]), 616 | return_counts=True, 617 | )[1], 0) 618 | start_idx, output = 0, torch.empty(0).to(x[0].device) 619 | for end_idx in idx_crops: 620 | _out = self.backbone(torch.cat(x[start_idx: end_idx])) 621 | # The output is a tuple with XCiT model. See: 622 | # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405 623 | if isinstance(_out, tuple): 624 | _out = _out[0] 625 | # accumulate outputs 626 | output = torch.cat((output, _out)) 627 | start_idx = end_idx 628 | # Run the head forward on the concatenated features. 629 | return self.head(output) 630 | 631 | 632 | def get_params_groups(model): 633 | regularized = [] 634 | not_regularized = [] 635 | for name, param in model.named_parameters(): 636 | if not param.requires_grad: 637 | continue 638 | # we do not regularize biases nor Norm parameters 639 | if name.endswith(".bias") or len(param.shape) == 1: 640 | not_regularized.append(param) 641 | else: 642 | regularized.append(param) 643 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 644 | 645 | 646 | def has_batchnorms(model): 647 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 648 | for name, module in model.named_modules(): 649 | if isinstance(module, bn_types): 650 | return True 651 | return False 652 | 653 | 654 | class PCA(): 655 | """ 656 | Class to compute and apply PCA. 657 | """ 658 | def __init__(self, dim=256, whit=0.5): 659 | self.dim = dim 660 | self.whit = whit 661 | self.mean = None 662 | 663 | def train_pca(self, cov): 664 | """ 665 | Takes a covariance matrix (np.ndarray) as input. 666 | """ 667 | d, v = np.linalg.eigh(cov) 668 | eps = d.max() * 1e-5 669 | n_0 = (d < eps).sum() 670 | if n_0 > 0: 671 | d[d < eps] = eps 672 | 673 | # total energy 674 | totenergy = d.sum() 675 | 676 | # sort eigenvectors with eigenvalues order 677 | idx = np.argsort(d)[::-1][:self.dim] 678 | d = d[idx] 679 | v = v[:, idx] 680 | 681 | print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0)) 682 | 683 | # for the whitening 684 | d = np.diag(1. / d**self.whit) 685 | 686 | # principal components 687 | self.dvt = np.dot(d, v.T) 688 | 689 | def apply(self, x): 690 | # input is from numpy 691 | if isinstance(x, np.ndarray): 692 | if self.mean is not None: 693 | x -= self.mean 694 | return np.dot(self.dvt, x.T).T 695 | 696 | # input is from torch and is on GPU 697 | if x.is_cuda: 698 | if self.mean is not None: 699 | x -= torch.cuda.FloatTensor(self.mean) 700 | return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1) 701 | 702 | # input if from torch, on CPU 703 | if self.mean is not None: 704 | x -= torch.FloatTensor(self.mean) 705 | return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1) 706 | 707 | 708 | def compute_ap(ranks, nres): 709 | """ 710 | Computes average precision for given ranked indexes. 711 | Arguments 712 | --------- 713 | ranks : zerro-based ranks of positive images 714 | nres : number of positive images 715 | Returns 716 | ------- 717 | ap : average precision 718 | """ 719 | 720 | # number of images ranked by the system 721 | nimgranks = len(ranks) 722 | 723 | # accumulate trapezoids in PR-plot 724 | ap = 0 725 | 726 | recall_step = 1. / nres 727 | 728 | for j in np.arange(nimgranks): 729 | rank = ranks[j] 730 | 731 | if rank == 0: 732 | precision_0 = 1. 733 | else: 734 | precision_0 = float(j) / rank 735 | 736 | precision_1 = float(j + 1) / (rank + 1) 737 | 738 | ap += (precision_0 + precision_1) * recall_step / 2. 739 | 740 | return ap 741 | 742 | 743 | def compute_map(ranks, gnd, kappas=[]): 744 | """ 745 | Computes the mAP for a given set of returned results. 746 | Usage: 747 | map = compute_map (ranks, gnd) 748 | computes mean average precsion (map) only 749 | map, aps, pr, prs = compute_map (ranks, gnd, kappas) 750 | computes mean average precision (map), average precision (aps) for each query 751 | computes mean precision at kappas (pr), precision at kappas (prs) for each query 752 | Notes: 753 | 1) ranks starts from 0, ranks.shape = db_size X #queries 754 | 2) The junk results (e.g., the query itself) should be declared in the gnd stuct array 755 | 3) If there are no positive images for some query, that query is excluded from the evaluation 756 | """ 757 | 758 | map = 0. 759 | nq = len(gnd) # number of queries 760 | aps = np.zeros(nq) 761 | pr = np.zeros(len(kappas)) 762 | prs = np.zeros((nq, len(kappas))) 763 | nempty = 0 764 | 765 | for i in np.arange(nq): 766 | qgnd = np.array(gnd[i]['ok']) 767 | 768 | # no positive images, skip from the average 769 | if qgnd.shape[0] == 0: 770 | aps[i] = float('nan') 771 | prs[i, :] = float('nan') 772 | nempty += 1 773 | continue 774 | 775 | try: 776 | qgndj = np.array(gnd[i]['junk']) 777 | except: 778 | qgndj = np.empty(0) 779 | 780 | # sorted positions of positive and junk images (0 based) 781 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)] 782 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)] 783 | 784 | k = 0; 785 | ij = 0; 786 | if len(junk): 787 | # decrease positions of positives based on the number of 788 | # junk images appearing before them 789 | ip = 0 790 | while (ip < len(pos)): 791 | while (ij < len(junk) and pos[ip] > junk[ij]): 792 | k += 1 793 | ij += 1 794 | pos[ip] = pos[ip] - k 795 | ip += 1 796 | 797 | # compute ap 798 | ap = compute_ap(pos, len(qgnd)) 799 | map = map + ap 800 | aps[i] = ap 801 | 802 | # compute precision @ k 803 | pos += 1 # get it to 1-based 804 | for j in np.arange(len(kappas)): 805 | kq = min(max(pos), kappas[j]); 806 | prs[i, j] = (pos <= kq).sum() / kq 807 | pr = pr + prs[i, :] 808 | 809 | map = map / (nq - nempty) 810 | pr = pr / (nq - nempty) 811 | 812 | return map, aps, pr, prs 813 | 814 | 815 | def multi_scale(samples, model): 816 | v = None 817 | for s in [1, 1/2**(1/2), 1/2]: # we use 3 different scales 818 | if s == 1: 819 | inp = samples.clone() 820 | else: 821 | inp = nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False) 822 | feats = model(inp).clone() 823 | if v is None: 824 | v = feats 825 | else: 826 | v += feats 827 | v /= 3 828 | v /= v.norm() 829 | return v 830 | -------------------------------------------------------------------------------- /video_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import glob 16 | import sys 17 | import argparse 18 | import cv2 19 | 20 | from tqdm import tqdm 21 | import matplotlib.pyplot as plt 22 | import torch 23 | import torch.nn as nn 24 | import torchvision 25 | from torchvision import transforms as pth_transforms 26 | import numpy as np 27 | from PIL import Image 28 | 29 | import utils 30 | import vision_transformer as vits 31 | 32 | 33 | FOURCC = { 34 | "mp4": cv2.VideoWriter_fourcc(*"MP4V"), 35 | "avi": cv2.VideoWriter_fourcc(*"XVID"), 36 | } 37 | DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 38 | 39 | 40 | class VideoGenerator: 41 | def __init__(self, args): 42 | self.args = args 43 | # self.model = None 44 | # Don't need to load model if you only want a video 45 | if not self.args.video_only: 46 | self.model = self.__load_model() 47 | 48 | def run(self): 49 | if self.args.input_path is None: 50 | print(f"Provided input path {self.args.input_path} is non valid.") 51 | sys.exit(1) 52 | else: 53 | if self.args.video_only: 54 | self._generate_video_from_images( 55 | self.args.input_path, self.args.output_path 56 | ) 57 | else: 58 | # If input path exists 59 | if os.path.exists(self.args.input_path): 60 | # If input is a video file 61 | if os.path.isfile(self.args.input_path): 62 | frames_folder = os.path.join(self.args.output_path, "frames") 63 | attention_folder = os.path.join( 64 | self.args.output_path, "attention" 65 | ) 66 | 67 | os.makedirs(frames_folder, exist_ok=True) 68 | os.makedirs(attention_folder, exist_ok=True) 69 | 70 | self._extract_frames_from_video( 71 | self.args.input_path, frames_folder 72 | ) 73 | 74 | self._inference( 75 | frames_folder, 76 | attention_folder, 77 | ) 78 | 79 | self._generate_video_from_images( 80 | attention_folder, self.args.output_path 81 | ) 82 | 83 | # If input is a folder of already extracted frames 84 | if os.path.isdir(self.args.input_path): 85 | attention_folder = os.path.join( 86 | self.args.output_path, "attention" 87 | ) 88 | 89 | os.makedirs(attention_folder, exist_ok=True) 90 | 91 | self._inference(self.args.input_path, attention_folder) 92 | 93 | self._generate_video_from_images( 94 | attention_folder, self.args.output_path 95 | ) 96 | 97 | # If input path doesn't exists 98 | else: 99 | print(f"Provided input path {self.args.input_path} doesn't exists.") 100 | sys.exit(1) 101 | 102 | def _extract_frames_from_video(self, inp: str, out: str): 103 | vidcap = cv2.VideoCapture(inp) 104 | self.args.fps = vidcap.get(cv2.CAP_PROP_FPS) 105 | 106 | print(f"Video: {inp} ({self.args.fps} fps)") 107 | print(f"Extracting frames to {out}") 108 | 109 | success, image = vidcap.read() 110 | count = 0 111 | while success: 112 | cv2.imwrite( 113 | os.path.join(out, f"frame-{count:04}.jpg"), 114 | image, 115 | ) 116 | success, image = vidcap.read() 117 | count += 1 118 | 119 | def _generate_video_from_images(self, inp: str, out: str): 120 | img_array = [] 121 | attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg"))) 122 | 123 | # Get size of the first image 124 | with open(attention_images_list[0], "rb") as f: 125 | img = Image.open(f) 126 | img = img.convert("RGB") 127 | size = (img.width, img.height) 128 | img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) 129 | 130 | print(f"Generating video {size} to {out}") 131 | 132 | for filename in tqdm(attention_images_list[1:]): 133 | with open(filename, "rb") as f: 134 | img = Image.open(f) 135 | img = img.convert("RGB") 136 | img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) 137 | 138 | out = cv2.VideoWriter( 139 | os.path.join(out, "video." + self.args.video_format), 140 | FOURCC[self.args.video_format], 141 | self.args.fps, 142 | size, 143 | ) 144 | 145 | for i in range(len(img_array)): 146 | out.write(img_array[i]) 147 | out.release() 148 | print("Done") 149 | 150 | def _inference(self, inp: str, out: str): 151 | print(f"Generating attention images to {out}") 152 | 153 | for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))): 154 | with open(img_path, "rb") as f: 155 | img = Image.open(f) 156 | img = img.convert("RGB") 157 | 158 | if self.args.resize is not None: 159 | transform = pth_transforms.Compose( 160 | [ 161 | pth_transforms.ToTensor(), 162 | pth_transforms.Resize(self.args.resize), 163 | pth_transforms.Normalize( 164 | (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 165 | ), 166 | ] 167 | ) 168 | else: 169 | transform = pth_transforms.Compose( 170 | [ 171 | pth_transforms.ToTensor(), 172 | pth_transforms.Normalize( 173 | (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 174 | ), 175 | ] 176 | ) 177 | 178 | img = transform(img) 179 | 180 | # make the image divisible by the patch size 181 | w, h = ( 182 | img.shape[1] - img.shape[1] % self.args.patch_size, 183 | img.shape[2] - img.shape[2] % self.args.patch_size, 184 | ) 185 | img = img[:, :w, :h].unsqueeze(0) 186 | 187 | w_featmap = img.shape[-2] // self.args.patch_size 188 | h_featmap = img.shape[-1] // self.args.patch_size 189 | 190 | attentions = self.model.get_last_selfattention(img.to(DEVICE)) 191 | 192 | nh = attentions.shape[1] # number of head 193 | 194 | # we keep only the output patch attention 195 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1) 196 | 197 | # we keep only a certain percentage of the mass 198 | val, idx = torch.sort(attentions) 199 | val /= torch.sum(val, dim=1, keepdim=True) 200 | cumval = torch.cumsum(val, dim=1) 201 | th_attn = cumval > (1 - self.args.threshold) 202 | idx2 = torch.argsort(idx) 203 | for head in range(nh): 204 | th_attn[head] = th_attn[head][idx2[head]] 205 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 206 | # interpolate 207 | th_attn = ( 208 | nn.functional.interpolate( 209 | th_attn.unsqueeze(0), 210 | scale_factor=self.args.patch_size, 211 | mode="nearest", 212 | )[0] 213 | .cpu() 214 | .numpy() 215 | ) 216 | 217 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 218 | attentions = ( 219 | nn.functional.interpolate( 220 | attentions.unsqueeze(0), 221 | scale_factor=self.args.patch_size, 222 | mode="nearest", 223 | )[0] 224 | .cpu() 225 | .numpy() 226 | ) 227 | 228 | # save attentions heatmaps 229 | fname = os.path.join(out, "attn-" + os.path.basename(img_path)) 230 | plt.imsave( 231 | fname=fname, 232 | arr=sum( 233 | attentions[i] * 1 / attentions.shape[0] 234 | for i in range(attentions.shape[0]) 235 | ), 236 | cmap="inferno", 237 | format="jpg", 238 | ) 239 | 240 | def __load_model(self): 241 | # build model 242 | model = vits.__dict__[self.args.arch]( 243 | patch_size=self.args.patch_size, num_classes=0 244 | ) 245 | for p in model.parameters(): 246 | p.requires_grad = False 247 | model.eval() 248 | model.to(DEVICE) 249 | 250 | if os.path.isfile(self.args.pretrained_weights): 251 | state_dict = torch.load(self.args.pretrained_weights, map_location="cpu") 252 | if ( 253 | self.args.checkpoint_key is not None 254 | and self.args.checkpoint_key in state_dict 255 | ): 256 | print( 257 | f"Take key {self.args.checkpoint_key} in provided checkpoint dict" 258 | ) 259 | state_dict = state_dict[self.args.checkpoint_key] 260 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 261 | # remove `backbone.` prefix induced by multicrop wrapper 262 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 263 | msg = model.load_state_dict(state_dict, strict=False) 264 | print( 265 | "Pretrained weights found at {} and loaded with msg: {}".format( 266 | self.args.pretrained_weights, msg 267 | ) 268 | ) 269 | else: 270 | print( 271 | "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate." 272 | ) 273 | url = None 274 | if self.args.arch == "vit_small" and self.args.patch_size == 16: 275 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 276 | elif self.args.arch == "vit_small" and self.args.patch_size == 8: 277 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper 278 | elif self.args.arch == "vit_base" and self.args.patch_size == 16: 279 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 280 | elif self.args.arch == "vit_base" and self.args.patch_size == 8: 281 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 282 | if url is not None: 283 | print( 284 | "Since no pretrained weights have been provided, we load the reference pretrained DINO weights." 285 | ) 286 | state_dict = torch.hub.load_state_dict_from_url( 287 | url="https://dl.fbaipublicfiles.com/dino/" + url 288 | ) 289 | model.load_state_dict(state_dict, strict=True) 290 | else: 291 | print( 292 | "There is no reference weights available for this model => We use random weights." 293 | ) 294 | return model 295 | 296 | 297 | def parse_args(): 298 | parser = argparse.ArgumentParser("Generation self-attention video") 299 | parser.add_argument( 300 | "--arch", 301 | default="vit_small", 302 | type=str, 303 | choices=["vit_tiny", "vit_small", "vit_base"], 304 | help="Architecture (support only ViT atm).", 305 | ) 306 | parser.add_argument( 307 | "--patch_size", default=8, type=int, help="Patch resolution of the self.model." 308 | ) 309 | parser.add_argument( 310 | "--pretrained_weights", 311 | default="", 312 | type=str, 313 | help="Path to pretrained weights to load.", 314 | ) 315 | parser.add_argument( 316 | "--checkpoint_key", 317 | default="teacher", 318 | type=str, 319 | help='Key to use in the checkpoint (example: "teacher")', 320 | ) 321 | parser.add_argument( 322 | "--input_path", 323 | required=True, 324 | type=str, 325 | help="""Path to a video file if you want to extract frames 326 | or to a folder of images already extracted by yourself. 327 | or to a folder of attention images.""", 328 | ) 329 | parser.add_argument( 330 | "--output_path", 331 | default="./", 332 | type=str, 333 | help="""Path to store a folder of frames and / or a folder of attention images. 334 | and / or a final video. Default to current directory.""", 335 | ) 336 | parser.add_argument( 337 | "--threshold", 338 | type=float, 339 | default=0.6, 340 | help="""We visualize masks 341 | obtained by thresholding the self-attention maps to keep xx percent of the mass.""", 342 | ) 343 | parser.add_argument( 344 | "--resize", 345 | default=None, 346 | type=int, 347 | nargs="+", 348 | help="""Apply a resize transformation to input image(s). Use if OOM error. 349 | Usage (single or W H): --resize 512, --resize 720 1280""", 350 | ) 351 | parser.add_argument( 352 | "--video_only", 353 | action="store_true", 354 | help="""Use this flag if you only want to generate a video and not all attention images. 355 | If used, --input_path must be set to the folder of attention images. Ex: ./attention/""", 356 | ) 357 | parser.add_argument( 358 | "--fps", 359 | default=30.0, 360 | type=float, 361 | help="FPS of input / output video. Automatically set if you extract frames from a video.", 362 | ) 363 | parser.add_argument( 364 | "--video_format", 365 | default="mp4", 366 | type=str, 367 | choices=["mp4", "avi"], 368 | help="Format of generated video (mp4 or avi).", 369 | ) 370 | 371 | return parser.parse_args() 372 | 373 | 374 | if __name__ == "__main__": 375 | args = parse_args() 376 | 377 | vg = VideoGenerator(args) 378 | vg.run() 379 | -------------------------------------------------------------------------------- /vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from utils import trunc_normal_ 25 | 26 | 27 | def drop_path(x, drop_prob: float = 0., training: bool = False): 28 | if drop_prob == 0. or not training: 29 | return x 30 | keep_prob = 1 - drop_prob 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 33 | random_tensor.floor_() # binarize 34 | output = x.div(keep_prob) * random_tensor 35 | return output 36 | 37 | 38 | class DropPath(nn.Module): 39 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 40 | """ 41 | def __init__(self, drop_prob=None): 42 | super(DropPath, self).__init__() 43 | self.drop_prob = drop_prob 44 | 45 | def forward(self, x): 46 | return drop_path(x, self.drop_prob, self.training) 47 | 48 | 49 | class Mlp(nn.Module): 50 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 51 | super().__init__() 52 | out_features = out_features or in_features 53 | hidden_features = hidden_features or in_features 54 | self.fc1 = nn.Linear(in_features, hidden_features) 55 | self.act = act_layer() 56 | self.fc2 = nn.Linear(hidden_features, out_features) 57 | self.drop = nn.Dropout(drop) 58 | 59 | def forward(self, x): 60 | x = self.fc1(x) 61 | x = self.act(x) 62 | x = self.drop(x) 63 | x = self.fc2(x) 64 | x = self.drop(x) 65 | return x 66 | 67 | 68 | class Attention(nn.Module): 69 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 70 | super().__init__() 71 | self.num_heads = num_heads 72 | head_dim = dim // num_heads 73 | self.scale = qk_scale or head_dim ** -0.5 74 | 75 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 76 | self.attn_drop = nn.Dropout(attn_drop) 77 | self.proj = nn.Linear(dim, dim) 78 | self.proj_drop = nn.Dropout(proj_drop) 79 | 80 | def forward(self, x): 81 | B, N, C = x.shape 82 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 83 | q, k, v = qkv[0], qkv[1], qkv[2] 84 | 85 | attn = (q @ k.transpose(-2, -1)) * self.scale 86 | attn = attn.softmax(dim=-1) 87 | attn = self.attn_drop(attn) 88 | 89 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 90 | x = self.proj(x) 91 | x = self.proj_drop(x) 92 | return x, attn 93 | 94 | 95 | class Block(nn.Module): 96 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 97 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 98 | super().__init__() 99 | self.norm1 = norm_layer(dim) 100 | self.attn = Attention( 101 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 103 | self.norm2 = norm_layer(dim) 104 | mlp_hidden_dim = int(dim * mlp_ratio) 105 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 106 | 107 | def forward(self, x, return_attention=False): 108 | y, attn = self.attn(self.norm1(x)) 109 | if return_attention: 110 | return attn 111 | x = x + self.drop_path(y) 112 | x = x + self.drop_path(self.mlp(self.norm2(x))) 113 | return x 114 | 115 | 116 | class PatchEmbed(nn.Module): 117 | """ Image to Patch Embedding 118 | """ 119 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 120 | super().__init__() 121 | num_patches = (img_size // patch_size) * (img_size // patch_size) 122 | self.img_size = img_size 123 | self.patch_size = patch_size 124 | self.num_patches = num_patches 125 | 126 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 127 | 128 | def forward(self, x): 129 | B, C, H, W = x.shape 130 | x = self.proj(x).flatten(2).transpose(1, 2) 131 | return x 132 | 133 | 134 | class VisionTransformer(nn.Module): 135 | """ Vision Transformer """ 136 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 137 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 138 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 139 | super().__init__() 140 | self.num_features = self.embed_dim = embed_dim 141 | 142 | self.patch_embed = PatchEmbed( 143 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 144 | num_patches = self.patch_embed.num_patches 145 | 146 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 147 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 148 | self.pos_drop = nn.Dropout(p=drop_rate) 149 | 150 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 151 | self.blocks = nn.ModuleList([ 152 | Block( 153 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 154 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 155 | for i in range(depth)]) 156 | self.norm = norm_layer(embed_dim) 157 | 158 | # Classifier head 159 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 160 | 161 | trunc_normal_(self.pos_embed, std=.02) 162 | trunc_normal_(self.cls_token, std=.02) 163 | self.apply(self._init_weights) 164 | 165 | def _init_weights(self, m): 166 | if isinstance(m, nn.Linear): 167 | trunc_normal_(m.weight, std=.02) 168 | if isinstance(m, nn.Linear) and m.bias is not None: 169 | nn.init.constant_(m.bias, 0) 170 | elif isinstance(m, nn.LayerNorm): 171 | nn.init.constant_(m.bias, 0) 172 | nn.init.constant_(m.weight, 1.0) 173 | 174 | def interpolate_pos_encoding(self, x, w, h): 175 | npatch = x.shape[1] - 1 176 | N = self.pos_embed.shape[1] - 1 177 | if npatch == N and w == h: 178 | return self.pos_embed 179 | class_pos_embed = self.pos_embed[:, 0] 180 | patch_pos_embed = self.pos_embed[:, 1:] 181 | dim = x.shape[-1] 182 | w0 = w // self.patch_embed.patch_size 183 | h0 = h // self.patch_embed.patch_size 184 | # we add a small number to avoid floating point error in the interpolation 185 | # see discussion at https://github.com/facebookresearch/dino/issues/8 186 | w0, h0 = w0 + 0.1, h0 + 0.1 187 | patch_pos_embed = nn.functional.interpolate( 188 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 189 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 190 | mode='bicubic', 191 | ) 192 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 193 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 194 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 195 | 196 | def prepare_tokens(self, x): 197 | B, nc, w, h = x.shape 198 | x = self.patch_embed(x) # patch linear embedding 199 | 200 | # add the [CLS] token to the embed patch tokens 201 | cls_tokens = self.cls_token.expand(B, -1, -1) 202 | x = torch.cat((cls_tokens, x), dim=1) 203 | 204 | # add positional encoding to each token 205 | x = x + self.interpolate_pos_encoding(x, w, h) 206 | 207 | return self.pos_drop(x) 208 | 209 | def forward(self, x): 210 | x = self.prepare_tokens(x) 211 | for blk in self.blocks: 212 | x = blk(x) 213 | x = self.norm(x) 214 | return x[:, 0] 215 | 216 | def get_last_selfattention(self, x): 217 | x = self.prepare_tokens(x) 218 | for i, blk in enumerate(self.blocks): 219 | if i < len(self.blocks) - 1: 220 | x = blk(x) 221 | else: 222 | # return attention of the last block 223 | return blk(x, return_attention=True) 224 | 225 | def get_intermediate_layers(self, x, n=1): 226 | x = self.prepare_tokens(x) 227 | # we return the output tokens from the `n` last blocks 228 | output = [] 229 | for i, blk in enumerate(self.blocks): 230 | x = blk(x) 231 | if len(self.blocks) - i <= n: 232 | output.append(self.norm(x)) 233 | return output 234 | 235 | 236 | def vit_tiny(patch_size=16, **kwargs): 237 | model = VisionTransformer( 238 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 239 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 240 | return model 241 | 242 | 243 | def vit_small(patch_size=16, **kwargs): 244 | model = VisionTransformer( 245 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 246 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 247 | return model 248 | 249 | 250 | def vit_base(patch_size=16, **kwargs): 251 | model = VisionTransformer( 252 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 253 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 254 | return model 255 | 256 | 257 | class DINOHead(nn.Module): 258 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 259 | super().__init__() 260 | nlayers = max(nlayers, 1) 261 | if nlayers == 1: 262 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 263 | else: 264 | layers = [nn.Linear(in_dim, hidden_dim)] 265 | if use_bn: 266 | layers.append(nn.BatchNorm1d(hidden_dim)) 267 | layers.append(nn.GELU()) 268 | for _ in range(nlayers - 2): 269 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 270 | if use_bn: 271 | layers.append(nn.BatchNorm1d(hidden_dim)) 272 | layers.append(nn.GELU()) 273 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 274 | self.mlp = nn.Sequential(*layers) 275 | self.apply(self._init_weights) 276 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 277 | self.last_layer.weight_g.data.fill_(1) 278 | if norm_last_layer: 279 | self.last_layer.weight_g.requires_grad = False 280 | 281 | def _init_weights(self, m): 282 | if isinstance(m, nn.Linear): 283 | trunc_normal_(m.weight, std=.02) 284 | if isinstance(m, nn.Linear) and m.bias is not None: 285 | nn.init.constant_(m.bias, 0) 286 | 287 | def forward(self, x): 288 | x = self.mlp(x) 289 | x = nn.functional.normalize(x, dim=-1, p=2) 290 | x = self.last_layer(x) 291 | return x 292 | -------------------------------------------------------------------------------- /visualize_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import argparse 17 | import cv2 18 | import random 19 | import colorsys 20 | import requests 21 | from io import BytesIO 22 | 23 | import skimage.io 24 | from skimage.measure import find_contours 25 | import matplotlib.pyplot as plt 26 | from matplotlib.patches import Polygon 27 | import torch 28 | import torch.nn as nn 29 | import torchvision 30 | from torchvision import transforms as pth_transforms 31 | import numpy as np 32 | from PIL import Image 33 | 34 | import utils 35 | import vision_transformer as vits 36 | 37 | 38 | def apply_mask(image, mask, color, alpha=0.5): 39 | for c in range(3): 40 | image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255 41 | return image 42 | 43 | 44 | def random_colors(N, bright=True): 45 | """ 46 | Generate random colors. 47 | """ 48 | brightness = 1.0 if bright else 0.7 49 | hsv = [(i / N, 1, brightness) for i in range(N)] 50 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 51 | random.shuffle(colors) 52 | return colors 53 | 54 | 55 | def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5): 56 | fig = plt.figure(figsize=figsize, frameon=False) 57 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 58 | ax.set_axis_off() 59 | fig.add_axes(ax) 60 | ax = plt.gca() 61 | 62 | N = 1 63 | mask = mask[None, :, :] 64 | # Generate random colors 65 | colors = random_colors(N) 66 | 67 | # Show area outside image boundaries. 68 | height, width = image.shape[:2] 69 | margin = 0 70 | ax.set_ylim(height + margin, -margin) 71 | ax.set_xlim(-margin, width + margin) 72 | ax.axis('off') 73 | masked_image = image.astype(np.uint32).copy() 74 | for i in range(N): 75 | color = colors[i] 76 | _mask = mask[i] 77 | if blur: 78 | _mask = cv2.blur(_mask,(10,10)) 79 | # Mask 80 | masked_image = apply_mask(masked_image, _mask, color, alpha) 81 | # Mask Polygon 82 | # Pad to ensure proper polygons for masks that touch image edges. 83 | if contour: 84 | padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2)) 85 | padded_mask[1:-1, 1:-1] = _mask 86 | contours = find_contours(padded_mask, 0.5) 87 | for verts in contours: 88 | # Subtract the padding and flip (y, x) to (x, y) 89 | verts = np.fliplr(verts) - 1 90 | p = Polygon(verts, facecolor="none", edgecolor=color) 91 | ax.add_patch(p) 92 | ax.imshow(masked_image.astype(np.uint8), aspect='auto') 93 | fig.savefig(fname) 94 | print(f"{fname} saved.") 95 | return 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser('Visualize Self-Attention maps') 100 | parser.add_argument('--arch', default='vit_small', type=str, 101 | choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).') 102 | parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.') 103 | parser.add_argument('--pretrained_weights', default='', type=str, 104 | help="Path to pretrained weights to load.") 105 | parser.add_argument("--checkpoint_key", default="teacher", type=str, 106 | help='Key to use in the checkpoint (example: "teacher")') 107 | parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.") 108 | parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.") 109 | parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.') 110 | parser.add_argument("--threshold", type=float, default=None, help="""We visualize masks 111 | obtained by thresholding the self-attention maps to keep xx% of the mass.""") 112 | args = parser.parse_args() 113 | 114 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 115 | # build model 116 | model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0) 117 | for p in model.parameters(): 118 | p.requires_grad = False 119 | model.eval() 120 | model.to(device) 121 | if os.path.isfile(args.pretrained_weights): 122 | state_dict = torch.load(args.pretrained_weights, map_location="cpu") 123 | if args.checkpoint_key is not None and args.checkpoint_key in state_dict: 124 | print(f"Take key {args.checkpoint_key} in provided checkpoint dict") 125 | state_dict = state_dict[args.checkpoint_key] 126 | # remove `module.` prefix 127 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 128 | # remove `backbone.` prefix induced by multicrop wrapper 129 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 130 | msg = model.load_state_dict(state_dict, strict=False) 131 | print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg)) 132 | else: 133 | print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") 134 | url = None 135 | if args.arch == "vit_small" and args.patch_size == 16: 136 | url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" 137 | elif args.arch == "vit_small" and args.patch_size == 8: 138 | url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper 139 | elif args.arch == "vit_base" and args.patch_size == 16: 140 | url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" 141 | elif args.arch == "vit_base" and args.patch_size == 8: 142 | url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" 143 | if url is not None: 144 | print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") 145 | state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) 146 | model.load_state_dict(state_dict, strict=True) 147 | else: 148 | print("There is no reference weights available for this model => We use random weights.") 149 | 150 | # open image 151 | if args.image_path is None: 152 | # user has not specified any image - we use our own image 153 | print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.") 154 | print("Since no image path have been provided, we take the first image in our paper.") 155 | response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png") 156 | img = Image.open(BytesIO(response.content)) 157 | img = img.convert('RGB') 158 | elif os.path.isfile(args.image_path): 159 | with open(args.image_path, 'rb') as f: 160 | img = Image.open(f) 161 | img = img.convert('RGB') 162 | else: 163 | print(f"Provided image path {args.image_path} is non valid.") 164 | sys.exit(1) 165 | transform = pth_transforms.Compose([ 166 | pth_transforms.Resize(args.image_size), 167 | pth_transforms.ToTensor(), 168 | pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 169 | ]) 170 | img = transform(img) 171 | 172 | # make the image divisible by the patch size 173 | w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size 174 | img = img[:, :w, :h].unsqueeze(0) 175 | 176 | w_featmap = img.shape[-2] // args.patch_size 177 | h_featmap = img.shape[-1] // args.patch_size 178 | 179 | attentions = model.get_last_selfattention(img.to(device)) 180 | 181 | nh = attentions.shape[1] # number of head 182 | 183 | # we keep only the output patch attention 184 | attentions = attentions[0, :, 0, 1:].reshape(nh, -1) 185 | 186 | if args.threshold is not None: 187 | # we keep only a certain percentage of the mass 188 | val, idx = torch.sort(attentions) 189 | val /= torch.sum(val, dim=1, keepdim=True) 190 | cumval = torch.cumsum(val, dim=1) 191 | th_attn = cumval > (1 - args.threshold) 192 | idx2 = torch.argsort(idx) 193 | for head in range(nh): 194 | th_attn[head] = th_attn[head][idx2[head]] 195 | th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float() 196 | # interpolate 197 | th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 198 | 199 | attentions = attentions.reshape(nh, w_featmap, h_featmap) 200 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy() 201 | 202 | # save attentions heatmaps 203 | os.makedirs(args.output_dir, exist_ok=True) 204 | torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png")) 205 | for j in range(nh): 206 | fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png") 207 | plt.imsave(fname=fname, arr=attentions[j], format='png') 208 | print(f"{fname} saved.") 209 | 210 | if args.threshold is not None: 211 | image = skimage.io.imread(os.path.join(args.output_dir, "img.png")) 212 | for j in range(nh): 213 | display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False) 214 | --------------------------------------------------------------------------------