├── .gitignore ├── CPDataset_HD.py ├── CPDataset_HD_new.py ├── DressCodeDataSet.py ├── DressCodeDataSet_new.py ├── README.md ├── TikTokDataSet.py ├── TikTokDataSet_new.py ├── VTTDataSet.py ├── VTTDataSet_train.py ├── VideoDataSet.py ├── WildVideoDataSet.py ├── anydoor_train.py ├── anydoor_train_TikTok.py ├── anydoor_train_video.py ├── autoencoder_kl_emasc.py ├── blended_cloth_pipeline.py ├── blended_cloth_pipeline_new.py ├── config.py ├── dino_module.py ├── dinov2 ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MODEL_CARD.md ├── README.md ├── conda.yaml ├── dinov2 │ ├── __init__.py │ ├── configs │ │ ├── __init__.py │ │ ├── eval │ │ │ ├── vitb14_pretrain.yaml │ │ │ ├── vitg14_pretrain.yaml │ │ │ ├── vitl14_pretrain.yaml │ │ │ └── vits14_pretrain.yaml │ │ ├── ssl_default_config.yaml │ │ └── train │ │ │ ├── vitg14.yaml │ │ │ ├── vitl14.yaml │ │ │ └── vitl16_short.yaml │ ├── data │ │ ├── __init__.py │ │ ├── adapters.py │ │ ├── augmentations.py │ │ ├── collate.py │ │ ├── datasets │ │ │ ├── __init__.py │ │ │ ├── decoders.py │ │ │ ├── extended.py │ │ │ ├── image_net.py │ │ │ └── image_net_22k.py │ │ ├── loaders.py │ │ ├── masking.py │ │ ├── samplers.py │ │ └── transforms.py │ ├── distributed │ │ └── __init__.py │ ├── eval │ │ ├── __init__.py │ │ ├── knn.py │ │ ├── linear.py │ │ ├── log_regression.py │ │ ├── metrics.py │ │ ├── setup.py │ │ └── utils.py │ ├── fsdp │ │ └── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── block.py │ │ ├── dino_head.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ └── swiglu_ffn.py │ ├── logging │ │ ├── __init__.py │ │ └── helpers.py │ ├── loss │ │ ├── __init__.py │ │ ├── dino_clstoken_loss.py │ │ ├── ibot_patch_loss.py │ │ └── koleo_loss.py │ ├── models │ │ ├── __init__.py │ │ └── vision_transformer.py │ ├── run │ │ ├── __init__.py │ │ ├── eval │ │ │ ├── knn.py │ │ │ ├── linear.py │ │ │ └── log_regression.py │ │ ├── submit.py │ │ └── train │ │ │ └── train.py │ ├── train │ │ ├── __init__.py │ │ ├── ssl_meta_arch.py │ │ └── train.py │ └── utils │ │ ├── __init__.py │ │ ├── cluster.py │ │ ├── config.py │ │ ├── dtype.py │ │ ├── param_groups.py │ │ └── utils.py ├── hubconf.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── scripts │ └── lint.sh ├── setup.cfg ├── setup.py └── vis.py ├── evaluate.py ├── extract_dino_fea.py ├── extract_dino_fea_vtt.py ├── figures ├── Figure12_TikTok_case1.mp4 ├── Figure12_TikTok_case5.mp4 ├── case1.gif ├── case2.gif ├── p1.png └── p2.png ├── infer.py ├── infer_new.py ├── infer_video.py ├── infer_video_guided.py ├── infer_video_mae_clip_guided.py ├── infer_video_mae_guided.py ├── infer_video_mae_guided_2.py ├── infer_video_new.py ├── infer_video_vtt_list.py ├── mae ├── FINETUNE.md ├── INSTALL.md ├── README.md ├── TARGETS ├── TODO.md ├── __pycache__ │ ├── models_mae.cpython-310.pyc │ └── models_mae.cpython-39.pyc ├── demo-preview.png ├── demo │ ├── goods.mp4 │ ├── qZ_lFjCiR1c_000104_000114.avi │ └── v_KW4TDvxIc_000223_000233.mp4 ├── engine_finetune.py ├── engine_pretrain.py ├── engine_test.py ├── launch_flow.sh ├── launch_local.sh ├── launch_tensorboard.sh ├── main_finetune.py ├── main_linprobe.py ├── main_pretrain.py ├── main_test.py ├── masks.png ├── models_mae.py ├── models_vit.py ├── mr-95-demo-vid-0.gif ├── mr-95-demo-vid-1.gif ├── mr-98-demo-vid-0.gif ├── mr-98-demo-vid-1.gif ├── rm_files │ ├── submitit_finetune.py │ ├── submitit_linprobe.py │ └── submitit_pretrain.py ├── run_finetune.py ├── run_pretrain.py ├── run_test.py ├── teaser.png ├── 
util │ ├── .misc.py.swp │ ├── __pycache__ │ │ ├── logging.cpython-310.pyc │ │ ├── logging.cpython-39.pyc │ │ ├── pos_embed.cpython-310.pyc │ │ ├── pos_embed.cpython-39.pyc │ │ ├── video_vit.cpython-310.pyc │ │ └── video_vit.cpython-39.pyc │ ├── crop.py │ ├── datasets.py │ ├── decoder │ │ ├── __pycache__ │ │ │ ├── rand_augment.cpython-310.pyc │ │ │ ├── rand_augment.cpython-39.pyc │ │ │ ├── random_erasing.cpython-310.pyc │ │ │ ├── random_erasing.cpython-39.pyc │ │ │ ├── transform.cpython-310.pyc │ │ │ ├── transform.cpython-39.pyc │ │ │ ├── utils.cpython-310.pyc │ │ │ └── utils.cpython-39.pyc │ │ ├── decoder.py │ │ ├── mixup.py │ │ ├── rand_augment.py │ │ ├── random_erasing.py │ │ ├── transform.py │ │ ├── utils.py │ │ └── video_container.py │ ├── env.py │ ├── kinetics.py │ ├── lars.py │ ├── logging.py │ ├── lr_decay.py │ ├── lr_sched.py │ ├── meters.py │ ├── misc.py │ ├── pos_embed.py │ └── video_vit.py └── video_mae_visualize.ipynb ├── one_stage_train.py ├── requirements.txt ├── train_cloth_vae_agnostic.py ├── train_cloth_vae_agnostic_TikTok.py ├── unet_emasc.py ├── utils.py ├── video_models ├── __pycache__ │ ├── attention.cpython-310.pyc │ ├── attention.cpython-39.pyc │ ├── resnet.cpython-310.pyc │ ├── resnet.cpython-39.pyc │ ├── unet.cpython-310.pyc │ ├── unet.cpython-39.pyc │ ├── unet_blocks.cpython-310.pyc │ ├── unet_blocks.cpython-39.pyc │ ├── video_pipeline.cpython-39.pyc │ ├── video_pipeline_guided.cpython-310.pyc │ ├── video_pipeline_guided.cpython-39.pyc │ ├── video_pipeline_mae_guided.cpython-310.pyc │ ├── video_pipeline_mae_guided.cpython-39.pyc │ └── video_pipeline_mae_guided_2.cpython-310.pyc ├── attentio2.py ├── attention.py ├── bug.txt ├── resnet.py ├── unet.py ├── unet_blocks.py ├── util.py ├── video_pipeline.py ├── video_pipeline_guided.py ├── video_pipeline_mae_clip_guided.py ├── video_pipeline_mae_clip_guided_2.py ├── video_pipeline_mae_guided.py ├── video_pipeline_mae_guided_2.py └── video_pipeline_new.py ├── videomae ├── DATASET.md ├── FINETUNE.md ├── INSTALL.md ├── LICENSE ├── MODEL_ZOO.md ├── NOTICE.md ├── PRETRAIN.md ├── README.md ├── __pycache__ │ ├── datasets.cpython-310.pyc │ ├── functional.cpython-310.pyc │ ├── kinetics.cpython-310.pyc │ ├── masking_generator.cpython-310.pyc │ ├── modeling_finetune.cpython-310.pyc │ ├── modeling_pretrain.cpython-310.pyc │ ├── rand_augment.cpython-310.pyc │ ├── random_erasing.cpython-310.pyc │ ├── ssv2.cpython-310.pyc │ ├── transforms.cpython-310.pyc │ ├── video_transforms.cpython-310.pyc │ └── volume_transforms.cpython-310.pyc ├── datasets.py ├── engine_for_finetuning.py ├── engine_for_pretraining.py ├── functional.py ├── kinetics.py ├── masking_generator.py ├── mixup.py ├── modeling_finetune.py ├── modeling_pretrain.py ├── optim_factory.py ├── rand_augment.py ├── random_erasing.py ├── run_class_finetuning.py ├── run_mae_pretraining.py ├── ssv2.py ├── transforms.py ├── utils.py ├── video_transforms.py ├── vis.sh └── volume_transforms.py ├── vis2.py └── wild_config.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.ipynb_checkpoints* 3 | *.jpg 4 | *.pth 5 | *.pyc* 6 | *.jpg 7 | *.png 8 | *.gif -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # data 2 | dataroot = '/data3/hzj/zalando-hd-resized' 3 | dataroot2 = '/data3/hzj/DressCode/upper_body/' 4 | dataroot3 = '/root/autodl-tmp/192_256/' 5 | dataroot4 = '/data3/hzj/TikTok_dataset2/' 6 | 
vtt_data_list = '/data1/hzj/192_256/custom_test_pairs.txt' 7 | fine_height = 512 8 | fine_width = 384 9 | with_one_hot = False 10 | with_parse_agnostic = True 11 | with_agnostic = True 12 | semantic_nc = 13 13 | 14 | # infer 15 | model_path = '/data1/hzj/magic-animate/pretrained_models/stable-diffusion-v1-5' # basic stable diffusion 16 | 17 | # VITON 18 | # # unet_path = 'model_VITON_512_DINO_large_large2/checkpoint-45000' 19 | # unet_path = 'model_VITON_512_DINO_large/checkpoint-40000' 20 | # # unet_path = 'trained_models/model_VITON_512_fixbug/checkpoint-120000' 21 | # # unet_path = '/data0/hzj/anydoor/trained_models/agnostic_norm_hair_have_background/checkpoint-50000/' # paper result 22 | # vae_path = 'trained_models/HR_VITON_vae' 23 | # test_dataset = 'VITON' 24 | # infer_datamode = 'test' 25 | # infer_data_list = 'test_pairs.txt' 26 | # infer_datasetting = 'unpaired' 27 | # out_dir = 'gen_test_VITON_new_scale2.0_0129_280' 28 | 29 | # DressCode 30 | # unet_path = '/data0/hzj/anydoor/trained_models/model_VITON_512_fixbug/checkpoint_120000' 31 | # unet_path = 'model_VITON_512_DINO_large_large2/checkpoint-45000' 32 | # vae_path = 'trained_models/HR_VITON_vae' 33 | # test_dataset = 'DressCode' 34 | # infer_datamode = 'test' 35 | # infer_data_list = 'test_pairs_unpaired.txt' 36 | # infer_datasetting = 'unpaired' 37 | # out_dir = 'gen_test_DressCode_new_scale2.0_0129' 38 | 39 | # TikTok 40 | # unet_path = '../anydoor/trained_models/model_TikTok_512_fixbug_1109_lip/checkpoint-150000' # tiktok model 41 | # unet_path = '/data1/hzj/anydoor/trained_models/model_TikTok_rebuttal_ft_from_TikTok/checkpoint-55000' 42 | # unet_path = '/data1/hzj/anydoor/trained_models/model_VITON_512_DINO_large_large_TikTok2/checkpoint-45000/' 43 | # unet_path = 'model_TikTok_eccv_ft/checkpoint-30000' 44 | unet_path = 'model_TikTok_eccv_ft/checkpoint-30000' 45 | # vae_path = '../anydoor/trained_models/HR_VITON_vae' # VITON vae 46 | vae_path = '/data1/hzj/magic-animate/pretrained_models/sd-vae-ft-mse/' 47 | # out_dir = 'gen_test_TikTOk_rebuttal_large_3.0' 48 | # out_dir = 'gen_test_TikTok_rebuttal_scale_102_321_video2' 49 | # out_dir = 'gen_test_134_138' 50 | out_dir = 'gen_test_128_102_2' 51 | test_dataset = 'TikTok' 52 | infer_datamode = 'test' 53 | # infer_data_list = 'train_unpairs_sp_159_167.txt' 54 | # infer_data_list = 'train_unpairs_sp_134_138.txt' 55 | # infer_data_list = 'train_unpairs_sp_179_167.txt' 56 | # infer_data_list = 'train_unpairs_sp_202_267.txt' 57 | infer_data_list = 'train_unpairs_sp_128_102.txt' 58 | infer_datasetting = 'unpaired' 59 | 60 | # # VTT 61 | # unet_path = 'trained_models/model_VTT_192_256_1030_fixbug/checkpoint-120000' # tiktok model 62 | # vae_path = 'trained_models/HR_VITON_vae' # VITON vae 63 | # output_root = 'gen_test_VTT' 64 | # test_dataset = 'VTT' 65 | # infer_datamode = 'test' 66 | # infer_data_list = 'test_pairs_sp.txt' 67 | # infer_datasetting = 'unpaired' 68 | # fine_height = 256 69 | # fine_width = 192 70 | 71 | # # unet_path = 'model_TikTok_512_fixbug_1109_atr/checkpoint-60000' 72 | # unet_path = 'model_TikTok_512_fixbug_1109_lip/checkpoint-150000' # tiktok model 73 | # # unet_path = 'model_VTT_192_256_1030_fixbug/checkpoint-80000' # VVT model 74 | # # unet_path = 'model_VTT_192_256_1023/checkpoint-62000' # VVT model 75 | # # unet_path = '/data1/hzj/agnostic_norm_hair_have_background/checkpoint-50000/' # VITON-HD and DressCode model 76 | # # vae_path = 'model_TikTok_vae_512_fixbug/checkpoint-4000' # TikTok vae, not use 77 | # vae_path = 
'../virtual_try_on_code/save_models/HR_VITON_vae' # VITON vae 78 | # # vae_path = 'model_VTT_vae/checkpoint-8000' # VTT vae, not use 79 | # # vae_path = 'parse_other_norm_nobackground_vae/checkpoint-14000' 80 | # out_dir = 'test_TikTok_video_demo/test1' 81 | 82 | # train data 83 | train_datamode = 'train' 84 | train_data_list = 'train_pairs.txt' 85 | # train_data_list = 'train_unpairs_sp_50_38.txt' 86 | train_datasetting = 'paired' 87 | 88 | # train 89 | pretrained_model_name_or_path = '/data1/hzj/magic-animate/pretrained_models/stable-diffusion-v1-5' 90 | # output_dir = 'model_VITON_512_DINO_large_large_TikTok2' # one-stage train 91 | output_dir = 'model_TikTok_eccv_ft_dino560_dresscode_06' # one-stage train 92 | # output_dir = 'model_TikTok_rebuttal_ft_from_TikTok' 93 | # output_dir = 'model_VTT_vae_fixhand' 94 | revision = None 95 | validation_prompt = None 96 | num_validation_images = 1 97 | validation_epoches = 1 98 | max_train_samples = None 99 | cache_dir = None 100 | seed = 42 101 | resolution = 512 102 | center_crop = False 103 | random_flip = True 104 | train_batch_size = 16 105 | max_train_steps = None 106 | num_train_epochs = 1000 107 | gradient_accumulation_steps = 1 108 | gradient_checkpointing = True 109 | learning_rate = 5e-05 110 | scale_lr = False 111 | lr_scheduler = 'constant' 112 | lr_warmup_steps = 0 113 | conditioning_dropout_prob = 0.00 114 | use_8bit_adam = False 115 | allow_tf32 = True 116 | use_ema = False 117 | non_ema_revision = None 118 | dataloader_num_workers = 10 119 | adam_beta1 = 0.9 120 | adam_beta2 = 0.999 121 | adam_weight_decay = 1e-2 122 | adam_epsilon = 1e-08 123 | max_grad_norm = 1.0 124 | logging_dir = 'logs' 125 | mixed_precision = 'fp16' 126 | report_to = 'tensorboard' 127 | local_rank = -1 128 | checkpointing_steps = 5000 129 | checkpoints_total_limit = None 130 | resume_from_checkpoint = 'latest' 131 | # resume_from_checkpoint = None 132 | enable_xformers_memory_efficient_attention = True 133 | -------------------------------------------------------------------------------- /dino_module.py: -------------------------------------------------------------------------------- 1 | ## from anydoor 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import sys 7 | 8 | class AbstractEncoder(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def encode(self, *args, **kwargs): 13 | raise NotImplementedError 14 | 15 | sys.path.append("./dinov2") 16 | import hubconf 17 | from omegaconf import OmegaConf 18 | # DINOv2_weight_path = '/root/autodl-tmp/AnyDoor/trained_models/dinov2_vitg14_pretrain.pth' 19 | DINOv2_weight_path = '/data1/hzj/anydoor/trained_models/dinov2_vitl14_pretrain.pth' 20 | 21 | class FrozenDinoV2Encoder(AbstractEncoder): 22 | """ 23 | Uses the DINOv2 encoder for image 24 | """ 25 | def __init__(self, freeze=True): 26 | super().__init__() 27 | # dinov2 = hubconf.dinov2_vitg14() 28 | dinov2 = hubconf.dinov2_vitl14() 29 | state_dict = torch.load(DINOv2_weight_path) 30 | dinov2.load_state_dict(state_dict, strict=False) 31 | self.model = dinov2 32 | if freeze: 33 | self.freeze() 34 | 35 | def freeze(self): 36 | self.model.eval() 37 | for param in self.model.parameters(): 38 | param.requires_grad = False 39 | 40 | def forward(self, image): 41 | features = self.model.forward_features(image) 42 | tokens = features["x_norm_patchtokens"] 43 | image_features = features["x_norm_clstoken"] 44 | image_features = image_features.unsqueeze(1) 45 | hint = torch.cat([image_features,tokens],1) # b,257,1536 
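        # Note (added comment, not in the original file): the shape annotation above
        # reflects the original AnyDoor setup with dinov2_vitg14 (embedding dim 1536).
        # With the dinov2_vitl14 backbone loaded here, a 224x224 input gives
        # 1 class token + 256 patch tokens, so the hint has shape (b, 257, 1024).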
46 | return hint 47 | 48 | if __name__ == '__main__': 49 | p = FrozenDinoV2Encoder() 50 | -------------------------------------------------------------------------------- /dinov2/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 
66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /dinov2/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DINOv2 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to DINOv2, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /dinov2/conda.yaml: -------------------------------------------------------------------------------- 1 | name: dinov2 2 | channels: 3 | - defaults 4 | - pytorch 5 | - nvidia 6 | - xformers 7 | - conda-forge 8 | dependencies: 9 | - python=3.9 10 | - pytorch::pytorch=2.0.0 11 | - pytorch::pytorch-cuda=11.7.0 12 | - pytorch::torchvision=0.15.0 13 | - omegaconf 14 | - torchmetrics=0.10.3 15 | - fvcore 16 | - iopath 17 | - xformers::xformers=0.0.18 18 | - pip 19 | - pip: 20 | - git+https://github.com/facebookincubator/submitit 21 | - --extra-index-url https://pypi.nvidia.com 22 | - cuml-cu11 23 | -------------------------------------------------------------------------------- /dinov2/dinov2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | __version__ = "0.0.1" 8 | -------------------------------------------------------------------------------- /dinov2/dinov2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pathlib 8 | 9 | from omegaconf import OmegaConf 10 | 11 | 12 | def load_config(config_name: str): 13 | config_filename = config_name + ".yaml" 14 | return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) 15 | 16 | 17 | dinov2_default_config = load_config("ssl_default_config") 18 | 19 | 20 | def load_and_merge_config(config_name: str): 21 | default_config = OmegaConf.create(dinov2_default_config) 22 | loaded_config = load_config(config_name) 23 | return OmegaConf.merge(default_config, loaded_config) 24 | -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vitb14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_base 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vitg14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_giant2 3 | patch_size: 14 4 | ffn_layer: swiglufused 5 | crops: 6 | global_crops_size: 518 # this is to set up the position embeddings properly 7 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vitl14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_large 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/eval/vits14_pretrain.yaml: -------------------------------------------------------------------------------- 1 | student: 2 | arch: vit_small 3 | patch_size: 14 4 | crops: 5 | global_crops_size: 518 # this is to set up the position embeddings properly 6 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/ssl_default_config.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | WEIGHTS: '' 3 | compute_precision: 4 | grad_scaler: true 5 | teacher: 6 | backbone: 7 | sharding_strategy: SHARD_GRAD_OP 8 | mixed_precision: 9 | param_dtype: fp16 10 | reduce_dtype: fp16 11 | buffer_dtype: fp32 12 | dino_head: 13 | sharding_strategy: SHARD_GRAD_OP 14 | mixed_precision: 15 | param_dtype: fp16 16 | reduce_dtype: fp16 17 | buffer_dtype: fp32 18 | ibot_head: 19 | sharding_strategy: SHARD_GRAD_OP 20 | mixed_precision: 21 | param_dtype: fp16 22 | reduce_dtype: fp16 23 | buffer_dtype: fp32 24 | student: 25 | backbone: 26 | sharding_strategy: SHARD_GRAD_OP 27 | mixed_precision: 28 | param_dtype: fp16 29 | reduce_dtype: fp16 30 | buffer_dtype: fp32 31 | dino_head: 32 | sharding_strategy: SHARD_GRAD_OP 33 | mixed_precision: 34 | param_dtype: fp16 35 | reduce_dtype: 
fp32 36 | buffer_dtype: fp32 37 | ibot_head: 38 | sharding_strategy: SHARD_GRAD_OP 39 | mixed_precision: 40 | param_dtype: fp16 41 | reduce_dtype: fp32 42 | buffer_dtype: fp32 43 | dino: 44 | loss_weight: 1.0 45 | head_n_prototypes: 65536 46 | head_bottleneck_dim: 256 47 | head_nlayers: 3 48 | head_hidden_dim: 2048 49 | koleo_loss_weight: 0.1 50 | ibot: 51 | loss_weight: 1.0 52 | mask_sample_probability: 0.5 53 | mask_ratio_min_max: 54 | - 0.1 55 | - 0.5 56 | separate_head: false 57 | head_n_prototypes: 65536 58 | head_bottleneck_dim: 256 59 | head_nlayers: 3 60 | head_hidden_dim: 2048 61 | train: 62 | batch_size_per_gpu: 64 63 | dataset_path: ImageNet:split=TRAIN 64 | output_dir: . 65 | saveckp_freq: 20 66 | seed: 0 67 | num_workers: 10 68 | OFFICIAL_EPOCH_LENGTH: 1250 69 | cache_dataset: true 70 | centering: "centering" # or "sinkhorn_knopp" 71 | student: 72 | arch: vit_large 73 | patch_size: 16 74 | drop_path_rate: 0.3 75 | layerscale: 1.0e-05 76 | drop_path_uniform: true 77 | pretrained_weights: '' 78 | ffn_layer: "mlp" 79 | block_chunks: 0 80 | qkv_bias: true 81 | proj_bias: true 82 | ffn_bias: true 83 | teacher: 84 | momentum_teacher: 0.992 85 | final_momentum_teacher: 1 86 | warmup_teacher_temp: 0.04 87 | teacher_temp: 0.07 88 | warmup_teacher_temp_epochs: 30 89 | optim: 90 | epochs: 100 91 | weight_decay: 0.04 92 | weight_decay_end: 0.4 93 | base_lr: 0.004 # learning rate for a batch size of 1024 94 | lr: 0. # will be set after applying scaling rule 95 | warmup_epochs: 10 96 | min_lr: 1.0e-06 97 | clip_grad: 3.0 98 | freeze_last_layer_epochs: 1 99 | scaling_rule: sqrt_wrt_1024 100 | patch_embed_lr_mult: 0.2 101 | layerwise_decay: 0.9 102 | adamw_beta1: 0.9 103 | adamw_beta2: 0.999 104 | crops: 105 | global_crops_scale: 106 | - 0.32 107 | - 1.0 108 | local_crops_number: 8 109 | local_crops_scale: 110 | - 0.05 111 | - 0.32 112 | global_crops_size: 224 113 | local_crops_size: 96 114 | evaluation: 115 | eval_period_iterations: 12500 116 | -------------------------------------------------------------------------------- /dinov2/dinov2/configs/train/vitg14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 12 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_giant2 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/train/vitl14.yaml: -------------------------------------------------------------------------------- 1 | dino: 2 | head_n_prototypes: 131072 3 | head_bottleneck_dim: 384 4 | ibot: 5 | separate_head: true 6 | head_n_prototypes: 131072 7 | train: 8 | batch_size_per_gpu: 32 9 | dataset_path: ImageNet22k 10 | centering: sinkhorn_knopp 11 | student: 12 | arch: vit_large 13 | patch_size: 14 14 | drop_path_rate: 0.4 15 | ffn_layer: swiglufused 16 | block_chunks: 4 17 | teacher: 18 | momentum_teacher: 0.994 19 | optim: 20 | epochs: 500 21 | weight_decay_end: 0.2 22 | base_lr: 2.0e-04 # learning rate for a batch size of 1024 23 | warmup_epochs: 80 24 | 
layerwise_decay: 1.0 25 | crops: 26 | local_crops_size: 98 -------------------------------------------------------------------------------- /dinov2/dinov2/configs/train/vitl16_short.yaml: -------------------------------------------------------------------------------- 1 | # this corresponds to the default config 2 | train: 3 | dataset_path: ImageNet:split=TRAIN 4 | batch_size_per_gpu: 64 5 | student: 6 | block_chunks: 4 7 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .adapters import DatasetWithEnumeratedTargets 8 | from .loaders import make_data_loader, make_dataset, SamplerType 9 | from .collate import collate_data_and_cast 10 | from .masking import MaskingGenerator 11 | from .augmentations import DataAugmentationDINO 12 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/adapters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Tuple 8 | 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class DatasetWithEnumeratedTargets(Dataset): 13 | def __init__(self, dataset): 14 | self._dataset = dataset 15 | 16 | def get_image_data(self, index: int) -> bytes: 17 | return self._dataset.get_image_data(index) 18 | 19 | def get_target(self, index: int) -> Tuple[Any, int]: 20 | target = self._dataset.get_target(index) 21 | return (index, target) 22 | 23 | def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: 24 | image, target = self._dataset[index] 25 | target = index if target is None else target 26 | return image, (index, target) 27 | 28 | def __len__(self) -> int: 29 | return len(self._dataset) 30 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | 9 | from torchvision import transforms 10 | 11 | from .transforms import ( 12 | GaussianBlur, 13 | make_normalize_transform, 14 | ) 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | class DataAugmentationDINO(object): 21 | def __init__( 22 | self, 23 | global_crops_scale, 24 | local_crops_scale, 25 | local_crops_number, 26 | global_crops_size=224, 27 | local_crops_size=96, 28 | ): 29 | self.global_crops_scale = global_crops_scale 30 | self.local_crops_scale = local_crops_scale 31 | self.local_crops_number = local_crops_number 32 | self.global_crops_size = global_crops_size 33 | self.local_crops_size = local_crops_size 34 | 35 | logger.info("###################################") 36 | logger.info("Using data augmentation parameters:") 37 | logger.info(f"global_crops_scale: {global_crops_scale}") 38 | logger.info(f"local_crops_scale: {local_crops_scale}") 39 | logger.info(f"local_crops_number: {local_crops_number}") 40 | logger.info(f"global_crops_size: {global_crops_size}") 41 | logger.info(f"local_crops_size: {local_crops_size}") 42 | logger.info("###################################") 43 | 44 | # random resized crop and flip 45 | self.geometric_augmentation_global = transforms.Compose( 46 | [ 47 | transforms.RandomResizedCrop( 48 | global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC 49 | ), 50 | transforms.RandomHorizontalFlip(p=0.5), 51 | ] 52 | ) 53 | 54 | self.geometric_augmentation_local = transforms.Compose( 55 | [ 56 | transforms.RandomResizedCrop( 57 | local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC 58 | ), 59 | transforms.RandomHorizontalFlip(p=0.5), 60 | ] 61 | ) 62 | 63 | # color distorsions / blurring 64 | color_jittering = transforms.Compose( 65 | [ 66 | transforms.RandomApply( 67 | [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], 68 | p=0.8, 69 | ), 70 | transforms.RandomGrayscale(p=0.2), 71 | ] 72 | ) 73 | 74 | global_transfo1_extra = GaussianBlur(p=1.0) 75 | 76 | global_transfo2_extra = transforms.Compose( 77 | [ 78 | GaussianBlur(p=0.1), 79 | transforms.RandomSolarize(threshold=128, p=0.2), 80 | ] 81 | ) 82 | 83 | local_transfo_extra = GaussianBlur(p=0.5) 84 | 85 | # normalization 86 | self.normalize = transforms.Compose( 87 | [ 88 | transforms.ToTensor(), 89 | make_normalize_transform(), 90 | ] 91 | ) 92 | 93 | self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) 94 | self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) 95 | self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) 96 | 97 | def __call__(self, image): 98 | output = {} 99 | 100 | # global crops: 101 | im1_base = self.geometric_augmentation_global(image) 102 | global_crop_1 = self.global_transfo1(im1_base) 103 | 104 | im2_base = self.geometric_augmentation_global(image) 105 | global_crop_2 = self.global_transfo2(im2_base) 106 | 107 | output["global_crops"] = [global_crop_1, global_crop_2] 108 | 109 | # global crops for teacher: 110 | output["global_crops_teacher"] = [global_crop_1, global_crop_2] 111 | 112 | # local crops: 113 | local_crops = [ 114 | self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) 115 | ] 116 | output["local_crops"] = local_crops 117 | output["offsets"] = () 118 | 119 | return output 120 | 
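A minimal usage sketch of DataAugmentationDINO (not part of the original file; the crop scales and sizes mirror the defaults in ssl_default_config.yaml, and the input image is a placeholder):

from PIL import Image

from dinov2.data.augmentations import DataAugmentationDINO

augmenter = DataAugmentationDINO(
    global_crops_scale=(0.32, 1.0),
    local_crops_scale=(0.05, 0.32),
    local_crops_number=8,
    global_crops_size=224,
    local_crops_size=96,
)

image = Image.new("RGB", (256, 256))  # placeholder PIL image for illustration
crops = augmenter(image)
# crops["global_crops"]: 2 normalized tensors of shape (3, 224, 224)
# crops["local_crops"]:  8 normalized tensors of shape (3, 96, 96)
# crops["global_crops_teacher"] reuses the same two global crops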
-------------------------------------------------------------------------------- /dinov2/dinov2/data/collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import random 9 | 10 | 11 | def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None): 12 | # dtype = torch.half # TODO: Remove 13 | 14 | n_global_crops = len(samples_list[0][0]["global_crops"]) 15 | n_local_crops = len(samples_list[0][0]["local_crops"]) 16 | 17 | collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]) 18 | 19 | collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) 20 | 21 | B = len(collated_global_crops) 22 | N = n_tokens 23 | n_samples_masked = int(B * mask_probability) 24 | probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) 25 | upperbound = 0 26 | masks_list = [] 27 | for i in range(0, n_samples_masked): 28 | prob_min = probs[i] 29 | prob_max = probs[i + 1] 30 | masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max))))) 31 | upperbound += int(N * prob_max) 32 | for i in range(n_samples_masked, B): 33 | masks_list.append(torch.BoolTensor(mask_generator(0))) 34 | 35 | random.shuffle(masks_list) 36 | 37 | collated_masks = torch.stack(masks_list).flatten(1) 38 | mask_indices_list = collated_masks.flatten().nonzero().flatten() 39 | 40 | masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks] 41 | 42 | return { 43 | "collated_global_crops": collated_global_crops.to(dtype), 44 | "collated_local_crops": collated_local_crops.to(dtype), 45 | "collated_masks": collated_masks, 46 | "mask_indices_list": mask_indices_list, 47 | "masks_weight": masks_weight, 48 | "upperbound": upperbound, 49 | "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long), 50 | } 51 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .image_net import ImageNet 8 | from .image_net_22k import ImageNet22k 9 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/datasets/decoders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
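A hedged sketch of how collate_data_and_cast is typically combined with MaskingGenerator as a DataLoader collate function (the argument values are illustrative and taken from ssl_default_config.yaml; the actual training entrypoint in dinov2/train/train.py derives them from the loaded config):

from functools import partial

import torch

from dinov2.data import MaskingGenerator, collate_data_and_cast

img_size, patch_size = 224, 14
n_tokens = (img_size // patch_size) ** 2  # 256 patch tokens per global crop

mask_generator = MaskingGenerator(
    input_size=(img_size // patch_size, img_size // patch_size),
    max_num_patches=int(0.5 * n_tokens),
)

collate_fn = partial(
    collate_data_and_cast,
    mask_ratio_tuple=(0.1, 0.5),  # ibot.mask_ratio_min_max in the config
    mask_probability=0.5,         # ibot.mask_sample_probability
    dtype=torch.half,
    n_tokens=n_tokens,
    mask_generator=mask_generator,
)

# Each dataset sample must be a (crops_dict, target) pair, where crops_dict is the
# output of DataAugmentationDINO; the collate function stacks the global/local crops
# and builds the boolean iBOT masks for the masked samples.
# loader = torch.utils.data.DataLoader(dataset, batch_size=..., collate_fn=collate_fn)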
6 | 7 | from io import BytesIO 8 | from typing import Any 9 | 10 | from PIL import Image 11 | 12 | 13 | class Decoder: 14 | def decode(self) -> Any: 15 | raise NotImplementedError 16 | 17 | 18 | class ImageDataDecoder(Decoder): 19 | def __init__(self, image_data: bytes) -> None: 20 | self._image_data = image_data 21 | 22 | def decode(self) -> Image: 23 | f = BytesIO(self._image_data) 24 | return Image.open(f).convert(mode="RGB") 25 | 26 | 27 | class TargetDecoder(Decoder): 28 | def __init__(self, target: Any): 29 | self._target = target 30 | 31 | def decode(self) -> Any: 32 | return self._target 33 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/datasets/extended.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Tuple 8 | 9 | from torchvision.datasets import VisionDataset 10 | 11 | from .decoders import TargetDecoder, ImageDataDecoder 12 | 13 | 14 | class ExtendedVisionDataset(VisionDataset): 15 | def __init__(self, *args, **kwargs) -> None: 16 | super().__init__(*args, **kwargs) # type: ignore 17 | 18 | def get_image_data(self, index: int) -> bytes: 19 | raise NotImplementedError 20 | 21 | def get_target(self, index: int) -> Any: 22 | raise NotImplementedError 23 | 24 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 25 | try: 26 | image_data = self.get_image_data(index) 27 | image = ImageDataDecoder(image_data).decode() 28 | except Exception as e: 29 | raise RuntimeError(f"can not read image for sample {index}") from e 30 | target = self.get_target(index) 31 | target = TargetDecoder(target).decode() 32 | 33 | if self.transforms is not None: 34 | image, target = self.transforms(image, target) 35 | 36 | return image, target 37 | 38 | def __len__(self) -> int: 39 | raise NotImplementedError 40 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/masking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | import math 9 | import numpy as np 10 | 11 | 12 | class MaskingGenerator: 13 | def __init__( 14 | self, 15 | input_size, 16 | num_masking_patches=None, 17 | min_num_patches=4, 18 | max_num_patches=None, 19 | min_aspect=0.3, 20 | max_aspect=None, 21 | ): 22 | if not isinstance(input_size, tuple): 23 | input_size = (input_size,) * 2 24 | self.height, self.width = input_size 25 | 26 | self.num_patches = self.height * self.width 27 | self.num_masking_patches = num_masking_patches 28 | 29 | self.min_num_patches = min_num_patches 30 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 31 | 32 | max_aspect = max_aspect or 1 / min_aspect 33 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 34 | 35 | def __repr__(self): 36 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 37 | self.height, 38 | self.width, 39 | self.min_num_patches, 40 | self.max_num_patches, 41 | self.num_masking_patches, 42 | self.log_aspect_ratio[0], 43 | self.log_aspect_ratio[1], 44 | ) 45 | return repr_str 46 | 47 | def get_shape(self): 48 | return self.height, self.width 49 | 50 | def _mask(self, mask, max_mask_patches): 51 | delta = 0 52 | for _ in range(10): 53 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 54 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 55 | h = int(round(math.sqrt(target_area * aspect_ratio))) 56 | w = int(round(math.sqrt(target_area / aspect_ratio))) 57 | if w < self.width and h < self.height: 58 | top = random.randint(0, self.height - h) 59 | left = random.randint(0, self.width - w) 60 | 61 | num_masked = mask[top : top + h, left : left + w].sum() 62 | # Overlap 63 | if 0 < h * w - num_masked <= max_mask_patches: 64 | for i in range(top, top + h): 65 | for j in range(left, left + w): 66 | if mask[i, j] == 0: 67 | mask[i, j] = 1 68 | delta += 1 69 | 70 | if delta > 0: 71 | break 72 | return delta 73 | 74 | def __call__(self, num_masking_patches=0): 75 | mask = np.zeros(shape=self.get_shape(), dtype=bool) 76 | mask_count = 0 77 | while mask_count < num_masking_patches: 78 | max_mask_patches = num_masking_patches - mask_count 79 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 80 | 81 | delta = self._mask(mask, max_mask_patches) 82 | if delta == 0: 83 | break 84 | else: 85 | mask_count += delta 86 | 87 | return mask 88 | -------------------------------------------------------------------------------- /dinov2/dinov2/data/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Sequence 8 | 9 | import torch 10 | from torchvision import transforms 11 | 12 | 13 | class GaussianBlur(transforms.RandomApply): 14 | """ 15 | Apply Gaussian Blur to the PIL image. 16 | """ 17 | 18 | def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): 19 | # NOTE: torchvision is applying 1 - probability to return the original image 20 | keep_p = 1 - p 21 | transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) 22 | super().__init__(transforms=[transform], p=keep_p) 23 | 24 | 25 | class MaybeToTensor(transforms.ToTensor): 26 | """ 27 | Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. 
28 | """ 29 | 30 | def __call__(self, pic): 31 | """ 32 | Args: 33 | pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. 34 | Returns: 35 | Tensor: Converted image. 36 | """ 37 | if isinstance(pic, torch.Tensor): 38 | return pic 39 | return super().__call__(pic) 40 | 41 | 42 | # Use timm's names 43 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 44 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 45 | 46 | 47 | def make_normalize_transform( 48 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 49 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 50 | ) -> transforms.Normalize: 51 | return transforms.Normalize(mean=mean, std=std) 52 | 53 | 54 | # This roughly matches torchvision's preset for classification training: 55 | # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 56 | def make_classification_train_transform( 57 | *, 58 | crop_size: int = 224, 59 | interpolation=transforms.InterpolationMode.BICUBIC, 60 | hflip_prob: float = 0.5, 61 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 62 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 63 | ): 64 | transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] 65 | if hflip_prob > 0.0: 66 | transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) 67 | transforms_list.extend( 68 | [ 69 | MaybeToTensor(), 70 | make_normalize_transform(mean=mean, std=std), 71 | ] 72 | ) 73 | return transforms.Compose(transforms_list) 74 | 75 | 76 | # This matches (roughly) torchvision's preset for classification evaluation: 77 | # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 78 | def make_classification_eval_transform( 79 | *, 80 | resize_size: int = 256, 81 | interpolation=transforms.InterpolationMode.BICUBIC, 82 | crop_size: int = 224, 83 | mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, 84 | std: Sequence[float] = IMAGENET_DEFAULT_STD, 85 | ) -> transforms.Compose: 86 | transforms_list = [ 87 | transforms.Resize(resize_size, interpolation=interpolation), 88 | transforms.CenterCrop(crop_size), 89 | MaybeToTensor(), 90 | make_normalize_transform(mean=mean, std=std), 91 | ] 92 | return transforms.Compose(transforms_list) 93 | -------------------------------------------------------------------------------- /dinov2/dinov2/eval/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dinov2/dinov2/eval/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from enum import Enum 8 | import logging 9 | from typing import Any, Dict, Optional 10 | 11 | import torch 12 | from torch import Tensor 13 | from torchmetrics import Metric, MetricCollection 14 | from torchmetrics.classification import MulticlassAccuracy 15 | from torchmetrics.utilities.data import dim_zero_cat, select_topk 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | class MetricType(Enum): 22 | MEAN_ACCURACY = "mean_accuracy" 23 | MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" 24 | PER_CLASS_ACCURACY = "per_class_accuracy" 25 | IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" 26 | 27 | @property 28 | def accuracy_averaging(self): 29 | return getattr(AccuracyAveraging, self.name, None) 30 | 31 | def __str__(self): 32 | return self.value 33 | 34 | 35 | class AccuracyAveraging(Enum): 36 | MEAN_ACCURACY = "micro" 37 | MEAN_PER_CLASS_ACCURACY = "macro" 38 | PER_CLASS_ACCURACY = "none" 39 | 40 | def __str__(self): 41 | return self.value 42 | 43 | 44 | def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): 45 | if metric_type.accuracy_averaging is not None: 46 | return build_topk_accuracy_metric( 47 | average_type=metric_type.accuracy_averaging, 48 | num_classes=num_classes, 49 | ks=(1, 5) if ks is None else ks, 50 | ) 51 | elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: 52 | return build_topk_imagenet_real_accuracy_metric( 53 | num_classes=num_classes, 54 | ks=(1, 5) if ks is None else ks, 55 | ) 56 | 57 | raise ValueError(f"Unknown metric type {metric_type}") 58 | 59 | 60 | def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): 61 | metrics: Dict[str, Metric] = { 62 | f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks 63 | } 64 | return MetricCollection(metrics) 65 | 66 | 67 | def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): 68 | metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} 69 | return MetricCollection(metrics) 70 | 71 | 72 | class ImageNetReaLAccuracy(Metric): 73 | is_differentiable: bool = False 74 | higher_is_better: Optional[bool] = None 75 | full_state_update: bool = False 76 | 77 | def __init__( 78 | self, 79 | num_classes: int, 80 | top_k: int = 1, 81 | **kwargs: Any, 82 | ) -> None: 83 | super().__init__(**kwargs) 84 | self.num_classes = num_classes 85 | self.top_k = top_k 86 | self.add_state("tp", [], dist_reduce_fx="cat") 87 | 88 | def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore 89 | # preds [B, D] 90 | # target [B, A] 91 | # preds_oh [B, D] with 0 and 1 92 | # select top K highest probabilities, use one hot representation 93 | preds_oh = select_topk(preds, self.top_k) 94 | # target_oh [B, D + 1] with 0 and 1 95 | target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) 96 | target = target.long() 97 | # for undefined targets (-1) use a fake value `num_classes` 98 | target[target == -1] = self.num_classes 99 | # fill targets, use one hot representation 100 | target_oh.scatter_(1, target, 1) 101 | # target_oh [B, D] (remove the fake target at index `num_classes`) 102 | target_oh = target_oh[:, :-1] 103 | # tp [B] with 0 and 1 104 | tp = (preds_oh * target_oh == 1).sum(dim=1) 105 | # at least one match between prediction and target 106 | tp.clip_(max=1) 107 | # ignore instances where no targets are defined 108 | mask = 
target_oh.sum(dim=1) > 0 109 | tp = tp[mask] 110 | self.tp.append(tp) # type: ignore 111 | 112 | def compute(self) -> Tensor: 113 | tp = dim_zero_cat(self.tp) # type: ignore 114 | return tp.float().mean() 115 | -------------------------------------------------------------------------------- /dinov2/dinov2/eval/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | from typing import Any, List, Optional, Tuple 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | from dinov2.models import build_model_from_cfg 14 | from dinov2.utils.config import setup 15 | import dinov2.utils.utils as dinov2_utils 16 | 17 | 18 | def get_args_parser( 19 | description: Optional[str] = None, 20 | parents: Optional[List[argparse.ArgumentParser]] = [], 21 | add_help: bool = True, 22 | ): 23 | parser = argparse.ArgumentParser( 24 | description=description, 25 | parents=parents, 26 | add_help=add_help, 27 | ) 28 | parser.add_argument( 29 | "--config-file", 30 | type=str, 31 | help="Model configuration file", 32 | ) 33 | parser.add_argument( 34 | "--pretrained-weights", 35 | type=str, 36 | help="Pretrained model weights", 37 | ) 38 | parser.add_argument( 39 | "--output-dir", 40 | default="", 41 | type=str, 42 | help="Output directory to write results and logs", 43 | ) 44 | parser.add_argument( 45 | "--opts", 46 | help="Extra configuration options", 47 | default=[], 48 | nargs="+", 49 | ) 50 | return parser 51 | 52 | 53 | def get_autocast_dtype(config): 54 | teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype 55 | if teacher_dtype_str == "fp16": 56 | return torch.half 57 | elif teacher_dtype_str == "bf16": 58 | return torch.bfloat16 59 | else: 60 | return torch.float 61 | 62 | 63 | def build_model_for_eval(config, pretrained_weights): 64 | model, _ = build_model_from_cfg(config, only_teacher=True) 65 | dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") 66 | model.eval() 67 | model.cuda() 68 | return model 69 | 70 | 71 | def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: 72 | cudnn.benchmark = True 73 | config = setup(args) 74 | model = build_model_for_eval(config, args.pretrained_weights) 75 | autocast_dtype = get_autocast_dtype(config) 76 | return model, autocast_dtype 77 | -------------------------------------------------------------------------------- /dinov2/dinov2/eval/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | from typing import Dict, Optional 9 | 10 | import torch 11 | from torch import nn 12 | from torchmetrics import MetricCollection 13 | 14 | from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader 15 | import dinov2.distributed as distributed 16 | from dinov2.logging import MetricLogger 17 | 18 | 19 | logger = logging.getLogger("dinov2") 20 | 21 | 22 | class ModelWithNormalize(torch.nn.Module): 23 | def __init__(self, model): 24 | super().__init__() 25 | self.model = model 26 | 27 | def forward(self, samples): 28 | return nn.functional.normalize(self.model(samples), dim=1, p=2) 29 | 30 | 31 | class ModelWithIntermediateLayers(nn.Module): 32 | def __init__(self, feature_model, n_last_blocks, autocast_ctx): 33 | super().__init__() 34 | self.feature_model = feature_model 35 | self.feature_model.eval() 36 | self.n_last_blocks = n_last_blocks 37 | self.autocast_ctx = autocast_ctx 38 | 39 | def forward(self, images): 40 | with torch.inference_mode(): 41 | with self.autocast_ctx(): 42 | features = self.feature_model.get_intermediate_layers( 43 | images, self.n_last_blocks, return_class_token=True 44 | ) 45 | return features 46 | 47 | 48 | @torch.inference_mode() 49 | def evaluate( 50 | model: nn.Module, 51 | data_loader, 52 | postprocessors: Dict[str, nn.Module], 53 | metrics: Dict[str, MetricCollection], 54 | device: torch.device, 55 | criterion: Optional[nn.Module] = None, 56 | ): 57 | model.eval() 58 | if criterion is not None: 59 | criterion.eval() 60 | 61 | for metric in metrics.values(): 62 | metric = metric.to(device) 63 | 64 | metric_logger = MetricLogger(delimiter=" ") 65 | header = "Test:" 66 | 67 | for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): 68 | outputs = model(samples.to(device)) 69 | targets = targets.to(device) 70 | 71 | if criterion is not None: 72 | loss = criterion(outputs, targets) 73 | metric_logger.update(loss=loss.item()) 74 | 75 | for k, metric in metrics.items(): 76 | metric_inputs = postprocessors[k](outputs, targets) 77 | metric.update(**metric_inputs) 78 | 79 | metric_logger.synchronize_between_processes() 80 | logger.info(f"Averaged stats: {metric_logger}") 81 | 82 | stats = {k: metric.compute() for k, metric in metrics.items()} 83 | metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 84 | return metric_logger_stats, stats 85 | 86 | 87 | def all_gather_and_flatten(tensor_rank): 88 | tensor_all_ranks = torch.empty( 89 | distributed.get_global_size(), 90 | *tensor_rank.shape, 91 | dtype=tensor_rank.dtype, 92 | device=tensor_rank.device, 93 | ) 94 | tensor_list = list(tensor_all_ranks.unbind(0)) 95 | torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) 96 | return tensor_all_ranks.flatten(end_dim=1) 97 | 98 | 99 | def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): 100 | dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) 101 | sample_count = len(dataset_with_enumerated_targets) 102 | data_loader = make_data_loader( 103 | dataset=dataset_with_enumerated_targets, 104 | batch_size=batch_size, 105 | num_workers=num_workers, 106 | sampler_type=SamplerType.DISTRIBUTED, 107 | drop_last=False, 108 | shuffle=False, 109 | ) 110 | return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) 111 | 112 | 113 | @torch.inference_mode() 114 | def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): 115 | gather_device = 
torch.device("cpu") if gather_on_cpu else torch.device("cuda") 116 | metric_logger = MetricLogger(delimiter=" ") 117 | features, all_labels = None, None 118 | for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): 119 | samples = samples.cuda(non_blocking=True) 120 | labels_rank = labels_rank.cuda(non_blocking=True) 121 | index = index.cuda(non_blocking=True) 122 | features_rank = model(samples).float() 123 | 124 | # init storage feature matrix 125 | if features is None: 126 | features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) 127 | labels_shape = list(labels_rank.shape) 128 | labels_shape[0] = sample_count 129 | all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) 130 | logger.info(f"Storing features into tensor of shape {features.shape}") 131 | 132 | # share indexes, features and labels between processes 133 | index_all = all_gather_and_flatten(index).to(gather_device) 134 | features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) 135 | labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) 136 | 137 | # update storage feature matrix 138 | if len(index_all) > 0: 139 | features.index_copy_(0, index_all, features_all_ranks) 140 | all_labels.index_copy_(0, index_all, labels_all_ranks) 141 | 142 | logger.info(f"Features shape: {tuple(features.shape)}") 143 | logger.info(f"Labels shape: {tuple(all_labels.shape)}") 144 | 145 | assert torch.all(all_labels > -1) 146 | 147 | return features, all_labels 148 | -------------------------------------------------------------------------------- /dinov2/dinov2/fsdp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | from typing import Any 9 | 10 | import torch 11 | import dinov2.distributed as distributed 12 | from functools import partial 13 | from fvcore.common.checkpoint import Checkpointer 14 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 15 | from torch.distributed.fsdp import ShardingStrategy 16 | from torch.distributed.fsdp import MixedPrecision 17 | from torch.distributed.fsdp import StateDictType 18 | from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler 19 | from torch.distributed.fsdp.wrap import ModuleWrapPolicy 20 | from torch.distributed.fsdp._runtime_utils import _reshard 21 | 22 | 23 | def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): 24 | sharding_strategy_dict = { 25 | "NO_SHARD": ShardingStrategy.NO_SHARD, 26 | "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, 27 | "FULL_SHARD": ShardingStrategy.FULL_SHARD, 28 | } 29 | 30 | dtype_dict = { 31 | "fp32": torch.float32, 32 | "fp16": torch.float16, 33 | "bf16": torch.bfloat16, 34 | } 35 | 36 | mixed_precision_config = MixedPrecision( 37 | param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], 38 | reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], 39 | buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], 40 | ) 41 | 42 | sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] 43 | 44 | local_rank = distributed.get_local_rank() 45 | 46 | fsdp_wrapper = partial( 47 | FSDP, 48 | sharding_strategy=sharding_strategy_config, 49 | mixed_precision=mixed_precision_config, 50 | device_id=local_rank, 51 | sync_module_states=True, 52 | use_orig_params=True, 53 | auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), 54 | ) 55 | return fsdp_wrapper 56 | 57 | 58 | def is_fsdp(x): 59 | return isinstance(x, FSDP) 60 | 61 | 62 | def is_sharded_fsdp(x): 63 | return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD 64 | 65 | 66 | def free_if_fsdp(x): 67 | if is_sharded_fsdp(x): 68 | handles = x._handles 69 | true_list = [True for h in handles] 70 | _reshard(x, handles, true_list) 71 | 72 | 73 | def get_fsdp_modules(x): 74 | return FSDP.fsdp_modules(x) 75 | 76 | 77 | def reshard_fsdp_model(x): 78 | for m in get_fsdp_modules(x): 79 | free_if_fsdp(m) 80 | 81 | 82 | def rankstr(): 83 | return f"rank_{distributed.get_global_rank()}" 84 | 85 | 86 | class FSDPCheckpointer(Checkpointer): 87 | def save(self, name: str, **kwargs: Any) -> None: 88 | """ 89 | Dump model and checkpointables to a file. 90 | 91 | Args: 92 | name (str): name of the file. 93 | kwargs (dict): extra arbitrary data to save. 
94 | """ 95 | if not self.save_dir or not self.save_to_disk: 96 | return 97 | 98 | data = {} 99 | with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): 100 | data["model"] = self.model.state_dict() 101 | 102 | # data["model"] = self.model.state_dict() 103 | for key, obj in self.checkpointables.items(): 104 | data[key] = obj.state_dict() 105 | data.update(kwargs) 106 | 107 | basename = f"{name}.{rankstr()}.pth" 108 | save_file = os.path.join(self.save_dir, basename) 109 | assert os.path.basename(save_file) == basename, basename 110 | self.logger.info("Saving checkpoint to {}".format(save_file)) 111 | with self.path_manager.open(save_file, "wb") as f: 112 | torch.save(data, f) 113 | self.tag_last_checkpoint(basename) 114 | 115 | def load(self, *args, **kwargs): 116 | with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): 117 | return super().load(*args, **kwargs) 118 | 119 | def has_checkpoint(self) -> bool: 120 | """ 121 | Returns: 122 | bool: whether a checkpoint exists in the target directory. 123 | """ 124 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") 125 | return self.path_manager.exists(save_file) 126 | 127 | def get_checkpoint_file(self) -> str: 128 | """ 129 | Returns: 130 | str: The latest checkpoint file in target directory. 131 | """ 132 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") 133 | try: 134 | with self.path_manager.open(save_file, "r") as f: 135 | last_saved = f.read().strip() 136 | except IOError: 137 | # if file doesn't exist, maybe because it has just been 138 | # deleted by a separate process 139 | return "" 140 | # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got 141 | # `Union[bytes, str]`. 142 | return os.path.join(self.save_dir, last_saved) 143 | 144 | def tag_last_checkpoint(self, last_filename_basename: str) -> None: 145 | """ 146 | Tag the last checkpoint. 147 | 148 | Args: 149 | last_filename_basename (str): the basename of the last filename. 150 | """ 151 | if distributed.is_enabled(): 152 | torch.distributed.barrier() 153 | save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") 154 | with self.path_manager.open(save_file, "w") as f: 155 | f.write(last_filename_basename) # pyre-ignore 156 | 157 | 158 | ShardedGradScaler = ShardedGradScaler 159 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
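A quick shape check for the attention layers in dinov2/layers/attention.py above (a minimal sketch; the dimensions are arbitrary). MemEffAttention is a drop-in replacement that uses xFormers when it is installed (GPU tensors expected in that case) and falls back to this plain implementation otherwise.

import torch

attn = Attention(dim=384, num_heads=6, qkv_bias=True)
x = torch.randn(2, 197, 384)   # (batch, tokens, dim)
y = attn(x)                    # output keeps the input shape
assert y.shape == x.shape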
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /dinov2/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /dinov2/dinov2/logging/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import functools 8 | import logging 9 | import os 10 | import sys 11 | from typing import Optional 12 | 13 | import dinov2.distributed as distributed 14 | from .helpers import MetricLogger, SmoothedValue 15 | 16 | 17 | # So that calling _configure_logger multiple times won't add many handlers 18 | @functools.lru_cache() 19 | def _configure_logger( 20 | name: Optional[str] = None, 21 | *, 22 | level: int = logging.DEBUG, 23 | output: Optional[str] = None, 24 | ): 25 | """ 26 | Configure a logger. 27 | 28 | Adapted from Detectron2. 29 | 30 | Args: 31 | name: The name of the logger to configure. 32 | level: The logging level to use. 33 | output: A file name or a directory to save log. If None, will not save log file. 34 | If ends with ".txt" or ".log", assumed to be a file name. 35 | Otherwise, logs will be saved to `output/log.txt`. 36 | 37 | Returns: 38 | The configured logger. 
39 | """ 40 | 41 | logger = logging.getLogger(name) 42 | logger.setLevel(level) 43 | logger.propagate = False 44 | 45 | # Loosely match Google glog format: 46 | # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg 47 | # but use a shorter timestamp and include the logger name: 48 | # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg 49 | fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " 50 | fmt_message = "%(message)s" 51 | fmt = fmt_prefix + fmt_message 52 | datefmt = "%Y%m%d %H:%M:%S" 53 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) 54 | 55 | # stdout logging for main worker only 56 | if distributed.is_main_process(): 57 | handler = logging.StreamHandler(stream=sys.stdout) 58 | handler.setLevel(logging.DEBUG) 59 | handler.setFormatter(formatter) 60 | logger.addHandler(handler) 61 | 62 | # file logging for all workers 63 | if output: 64 | if os.path.splitext(output)[-1] in (".txt", ".log"): 65 | filename = output 66 | else: 67 | filename = os.path.join(output, "logs", "log.txt") 68 | 69 | if not distributed.is_main_process(): 70 | global_rank = distributed.get_global_rank() 71 | filename = filename + ".rank{}".format(global_rank) 72 | 73 | os.makedirs(os.path.dirname(filename), exist_ok=True) 74 | 75 | handler = logging.StreamHandler(open(filename, "a")) 76 | handler.setLevel(logging.DEBUG) 77 | handler.setFormatter(formatter) 78 | logger.addHandler(handler) 79 | 80 | return logger 81 | 82 | 83 | def setup_logging( 84 | output: Optional[str] = None, 85 | *, 86 | name: Optional[str] = None, 87 | level: int = logging.DEBUG, 88 | capture_warnings: bool = True, 89 | ) -> None: 90 | """ 91 | Setup logging. 92 | 93 | Args: 94 | output: A file name or a directory to save log files. If None, log 95 | files will not be saved. If output ends with ".txt" or ".log", it 96 | is assumed to be a file name. 97 | Otherwise, logs will be saved to `output/log.txt`. 98 | name: The name of the logger to configure, by default the root logger. 99 | level: The logging level to use. 100 | capture_warnings: Whether warnings should be captured as logs. 101 | """ 102 | logging.captureWarnings(capture_warnings) 103 | _configure_logger(name, level=level, output=output) 104 | -------------------------------------------------------------------------------- /dinov2/dinov2/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_clstoken_loss import DINOLoss 8 | from .ibot_patch_loss import iBOTPatchLoss 9 | from .koleo_loss import KoLeoLoss 10 | -------------------------------------------------------------------------------- /dinov2/dinov2/loss/dino_clstoken_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | 13 | class DINOLoss(nn.Module): 14 | def __init__( 15 | self, 16 | out_dim, 17 | student_temp=0.1, 18 | center_momentum=0.9, 19 | ): 20 | super().__init__() 21 | self.student_temp = student_temp 22 | self.center_momentum = center_momentum 23 | self.register_buffer("center", torch.zeros(1, out_dim)) 24 | self.updated = True 25 | self.reduce_handle = None 26 | self.len_teacher_output = None 27 | self.async_batch_center = None 28 | 29 | @torch.no_grad() 30 | def softmax_center_teacher(self, teacher_output, teacher_temp): 31 | self.apply_center_update() 32 | # teacher centering and sharpening 33 | return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) 34 | 35 | @torch.no_grad() 36 | def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): 37 | teacher_output = teacher_output.float() 38 | world_size = dist.get_world_size() if dist.is_initialized() else 1 39 | Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper 40 | B = Q.shape[1] * world_size # number of samples to assign 41 | K = Q.shape[0] # how many prototypes 42 | 43 | # make the matrix sums to 1 44 | sum_Q = torch.sum(Q) 45 | if dist.is_initialized(): 46 | dist.all_reduce(sum_Q) 47 | Q /= sum_Q 48 | 49 | for it in range(n_iterations): 50 | # normalize each row: total weight per prototype must be 1/K 51 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 52 | if dist.is_initialized(): 53 | dist.all_reduce(sum_of_rows) 54 | Q /= sum_of_rows 55 | Q /= K 56 | 57 | # normalize each column: total weight per sample must be 1/B 58 | Q /= torch.sum(Q, dim=0, keepdim=True) 59 | Q /= B 60 | 61 | Q *= B # the columns must sum to 1 so that Q is an assignment 62 | return Q.t() 63 | 64 | def forward(self, student_output_list, teacher_out_softmaxed_centered_list): 65 | """ 66 | Cross-entropy between softmax outputs of the teacher and student networks. 67 | """ 68 | # TODO: Use cross_entropy_distribution here 69 | total_loss = 0 70 | for s in student_output_list: 71 | lsm = F.log_softmax(s / self.student_temp, dim=-1) 72 | for t in teacher_out_softmaxed_centered_list: 73 | loss = torch.sum(t * lsm, dim=-1) 74 | total_loss -= loss.mean() 75 | return total_loss 76 | 77 | @torch.no_grad() 78 | def update_center(self, teacher_output): 79 | self.reduce_center_update(teacher_output) 80 | 81 | @torch.no_grad() 82 | def reduce_center_update(self, teacher_output): 83 | self.updated = False 84 | self.len_teacher_output = len(teacher_output) 85 | self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 86 | if dist.is_initialized(): 87 | self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) 88 | 89 | @torch.no_grad() 90 | def apply_center_update(self): 91 | if self.updated is False: 92 | world_size = dist.get_world_size() if dist.is_initialized() else 1 93 | 94 | if self.reduce_handle is not None: 95 | self.reduce_handle.wait() 96 | _t = self.async_batch_center / (self.len_teacher_output * world_size) 97 | 98 | self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) 99 | 100 | self.updated = True 101 | -------------------------------------------------------------------------------- /dinov2/dinov2/loss/koleo_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | # import torch.distributed as dist 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class KoLeoLoss(nn.Module): 20 | """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" 21 | 22 | def __init__(self): 23 | super().__init__() 24 | self.pdist = nn.PairwiseDistance(2, eps=1e-8) 25 | 26 | def pairwise_NNs_inner(self, x): 27 | """ 28 | Pairwise nearest neighbors for L2-normalized vectors. 29 | Uses Torch rather than Faiss to remain on GPU. 30 | """ 31 | # parwise dot products (= inverse distance) 32 | dots = torch.mm(x, x.t()) 33 | n = x.shape[0] 34 | dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 35 | # max inner prod -> min distance 36 | _, I = torch.max(dots, dim=1) # noqa: E741 37 | return I 38 | 39 | def forward(self, student_output, eps=1e-8): 40 | """ 41 | Args: 42 | student_output (BxD): backbone output of student 43 | """ 44 | with torch.cuda.amp.autocast(enabled=False): 45 | student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) 46 | I = self.pairwise_NNs_inner(student_output) # noqa: E741 47 | distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B 48 | loss = -torch.log(distances + eps).mean() 49 | return loss 50 | -------------------------------------------------------------------------------- /dinov2/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | from . import vision_transformer as vits 10 | 11 | 12 | logger = logging.getLogger("dinov2") 13 | 14 | 15 | def build_model(args, only_teacher=False, img_size=224): 16 | args.arch = args.arch.removesuffix("_memeff") 17 | if "vit" in args.arch: 18 | vit_kwargs = dict( 19 | img_size=img_size, 20 | patch_size=args.patch_size, 21 | init_values=args.layerscale, 22 | ffn_layer=args.ffn_layer, 23 | block_chunks=args.block_chunks, 24 | qkv_bias=args.qkv_bias, 25 | proj_bias=args.proj_bias, 26 | ffn_bias=args.ffn_bias, 27 | ) 28 | teacher = vits.__dict__[args.arch](**vit_kwargs) 29 | if only_teacher: 30 | return teacher, teacher.embed_dim 31 | student = vits.__dict__[args.arch]( 32 | **vit_kwargs, 33 | drop_path_rate=args.drop_path_rate, 34 | drop_path_uniform=args.drop_path_uniform, 35 | ) 36 | embed_dim = student.embed_dim 37 | return student, teacher, embed_dim 38 | 39 | 40 | def build_model_from_cfg(cfg, only_teacher=False): 41 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 42 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
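A single-process sketch of the two losses above (dino_clstoken_loss.py and koleo_loss.py). The output dimension and batch size are arbitrary; during training the inputs come from the student/teacher DINO heads and the distributed all-reduce paths become active.

import torch

dino_loss = DINOLoss(out_dim=1024)
student_crops = [torch.randn(8, 1024)]                  # list of per-crop student logits
teacher_logits = torch.randn(8, 1024)
teacher_soft = dino_loss.softmax_center_teacher(teacher_logits, teacher_temp=0.07)
loss = dino_loss(student_crops, [teacher_soft])
dino_loss.update_center(teacher_logits)                 # EMA update of the centering buffer

koleo = KoLeoLoss()
spread_loss = koleo(torch.randn(8, 256))                # pushes features to spread apart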
6 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/eval/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.knn import get_args_parser as get_knn_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.knn import main as knn_main 25 | 26 | self._setup_args() 27 | knn_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 k-NN evaluation" 47 | knn_args_parser = get_knn_args_parser(add_help=False) 48 | parents = [knn_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 55 | submit_jobs(Evaluator, args, name="dinov2:knn") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/eval/linear.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.linear import get_args_parser as get_linear_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.linear import main as linear_main 25 | 26 | self._setup_args() 27 | linear_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 linear evaluation" 47 | linear_args_parser = get_linear_args_parser(add_help=False) 48 | parents = [linear_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 55 | submit_jobs(Evaluator, args, name="dinov2:linear") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/eval/log_regression.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser 12 | from dinov2.logging import setup_logging 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Evaluator: 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.eval.log_regression import main as log_regression_main 25 | 26 | self._setup_args() 27 | log_regression_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 logistic evaluation" 47 | log_regression_args_parser = get_log_regression_args_parser(add_help=False) 48 | parents = [log_regression_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 
55 | submit_jobs(Evaluator, args, name="dinov2:logreg") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/submit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import logging 9 | import os 10 | from pathlib import Path 11 | from typing import List, Optional 12 | 13 | import submitit 14 | 15 | from dinov2.utils.cluster import ( 16 | get_slurm_executor_parameters, 17 | get_slurm_partition, 18 | get_user_checkpoint_path, 19 | ) 20 | 21 | 22 | logger = logging.getLogger("dinov2") 23 | 24 | 25 | def get_args_parser( 26 | description: Optional[str] = None, 27 | parents: Optional[List[argparse.ArgumentParser]] = [], 28 | add_help: bool = True, 29 | ) -> argparse.ArgumentParser: 30 | slurm_partition = get_slurm_partition() 31 | parser = argparse.ArgumentParser( 32 | description=description, 33 | parents=parents, 34 | add_help=add_help, 35 | ) 36 | parser.add_argument( 37 | "--ngpus", 38 | "--gpus", 39 | "--gpus-per-node", 40 | default=8, 41 | type=int, 42 | help="Number of GPUs to request on each node", 43 | ) 44 | parser.add_argument( 45 | "--nodes", 46 | "--nnodes", 47 | default=2, 48 | type=int, 49 | help="Number of nodes to request", 50 | ) 51 | parser.add_argument( 52 | "--timeout", 53 | default=2800, 54 | type=int, 55 | help="Duration of the job", 56 | ) 57 | parser.add_argument( 58 | "--partition", 59 | default=slurm_partition, 60 | type=str, 61 | help="Partition where to submit", 62 | ) 63 | parser.add_argument( 64 | "--use-volta32", 65 | action="store_true", 66 | help="Request V100-32GB GPUs", 67 | ) 68 | parser.add_argument( 69 | "--comment", 70 | default="", 71 | type=str, 72 | help="Comment to pass to scheduler, e.g. 
priority message", 73 | ) 74 | parser.add_argument( 75 | "--exclude", 76 | default="", 77 | type=str, 78 | help="Nodes to exclude", 79 | ) 80 | return parser 81 | 82 | 83 | def get_shared_folder() -> Path: 84 | user_checkpoint_path = get_user_checkpoint_path() 85 | if user_checkpoint_path is None: 86 | raise RuntimeError("Path to user checkpoint cannot be determined") 87 | path = user_checkpoint_path / "experiments" 88 | path.mkdir(exist_ok=True) 89 | return path 90 | 91 | 92 | def submit_jobs(task_class, args, name: str): 93 | if not args.output_dir: 94 | args.output_dir = str(get_shared_folder() / "%j") 95 | 96 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 97 | executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) 98 | 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs["slurm_constraint"] = "volta32gb" 102 | if args.comment: 103 | kwargs["slurm_comment"] = args.comment 104 | if args.exclude: 105 | kwargs["slurm_exclude"] = args.exclude 106 | 107 | executor_params = get_slurm_executor_parameters( 108 | nodes=args.nodes, 109 | num_gpus_per_node=args.ngpus, 110 | timeout_min=args.timeout, # max is 60 * 72 111 | slurm_signal_delay_s=120, 112 | slurm_partition=args.partition, 113 | **kwargs, 114 | ) 115 | executor.update_parameters(name=name, **executor_params) 116 | 117 | task = task_class(args) 118 | job = executor.submit(task) 119 | 120 | logger.info(f"Submitted job_id: {job.job_id}") 121 | str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) 122 | logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") 123 | -------------------------------------------------------------------------------- /dinov2/dinov2/run/train/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | from dinov2.logging import setup_logging 12 | from dinov2.train import get_args_parser as get_train_args_parser 13 | from dinov2.run.submit import get_args_parser, submit_jobs 14 | 15 | 16 | logger = logging.getLogger("dinov2") 17 | 18 | 19 | class Trainer(object): 20 | def __init__(self, args): 21 | self.args = args 22 | 23 | def __call__(self): 24 | from dinov2.train import main as train_main 25 | 26 | self._setup_args() 27 | train_main(self.args) 28 | 29 | def checkpoint(self): 30 | import submitit 31 | 32 | logger.info(f"Requeuing {self.args}") 33 | empty = type(self)(self.args) 34 | return submitit.helpers.DelayedSubmission(empty) 35 | 36 | def _setup_args(self): 37 | import submitit 38 | 39 | job_env = submitit.JobEnvironment() 40 | self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) 41 | logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 42 | logger.info(f"Args: {self.args}") 43 | 44 | 45 | def main(): 46 | description = "Submitit launcher for DINOv2 training" 47 | train_args_parser = get_train_args_parser(add_help=False) 48 | parents = [train_args_parser] 49 | args_parser = get_args_parser(description=description, parents=parents) 50 | args = args_parser.parse_args() 51 | 52 | setup_logging() 53 | 54 | assert os.path.exists(args.config_file), "Configuration file does not exist!" 
55 | submit_jobs(Trainer, args, name="dinov2:train") 56 | return 0 57 | 58 | 59 | if __name__ == "__main__": 60 | sys.exit(main()) 61 | -------------------------------------------------------------------------------- /dinov2/dinov2/train/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .train import get_args_parser, main 8 | from .ssl_meta_arch import SSLMetaArch 9 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/cluster.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from enum import Enum 8 | import os 9 | from pathlib import Path 10 | from typing import Any, Dict, Optional 11 | 12 | 13 | class ClusterType(Enum): 14 | AWS = "aws" 15 | FAIR = "fair" 16 | RSC = "rsc" 17 | 18 | 19 | def _guess_cluster_type() -> ClusterType: 20 | uname = os.uname() 21 | if uname.sysname == "Linux": 22 | if uname.release.endswith("-aws"): 23 | # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" 24 | return ClusterType.AWS 25 | elif uname.nodename.startswith("rsc"): 26 | # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" 27 | return ClusterType.RSC 28 | 29 | return ClusterType.FAIR 30 | 31 | 32 | def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: 33 | if cluster_type is None: 34 | return _guess_cluster_type() 35 | 36 | return cluster_type 37 | 38 | 39 | def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 40 | cluster_type = get_cluster_type(cluster_type) 41 | if cluster_type is None: 42 | return None 43 | 44 | CHECKPOINT_DIRNAMES = { 45 | ClusterType.AWS: "checkpoints", 46 | ClusterType.FAIR: "checkpoint", 47 | ClusterType.RSC: "checkpoint/dino", 48 | } 49 | return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] 50 | 51 | 52 | def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: 53 | checkpoint_path = get_checkpoint_path(cluster_type) 54 | if checkpoint_path is None: 55 | return None 56 | 57 | username = os.environ.get("USER") 58 | assert username is not None 59 | return checkpoint_path / username 60 | 61 | 62 | def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: 63 | cluster_type = get_cluster_type(cluster_type) 64 | if cluster_type is None: 65 | return None 66 | 67 | SLURM_PARTITIONS = { 68 | ClusterType.AWS: "learnlab", 69 | ClusterType.FAIR: "learnlab", 70 | ClusterType.RSC: "learn", 71 | } 72 | return SLURM_PARTITIONS[cluster_type] 73 | 74 | 75 | def get_slurm_executor_parameters( 76 | nodes: int, 
num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs 77 | ) -> Dict[str, Any]: 78 | # create default parameters 79 | params = { 80 | "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html 81 | "gpus_per_node": num_gpus_per_node, 82 | "tasks_per_node": num_gpus_per_node, # one task per GPU 83 | "cpus_per_task": 10, 84 | "nodes": nodes, 85 | "slurm_partition": get_slurm_partition(cluster_type), 86 | } 87 | # apply cluster-specific adjustments 88 | cluster_type = get_cluster_type(cluster_type) 89 | if cluster_type == ClusterType.AWS: 90 | params["cpus_per_task"] = 12 91 | del params["mem_gb"] 92 | elif cluster_type == ClusterType.RSC: 93 | params["cpus_per_task"] = 12 94 | # set additional parameters / apply overrides 95 | params.update(kwargs) 96 | return params 97 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import logging 9 | import os 10 | 11 | from omegaconf import OmegaConf 12 | 13 | import dinov2.distributed as distributed 14 | from dinov2.logging import setup_logging 15 | from dinov2.utils import utils 16 | from dinov2.configs import dinov2_default_config 17 | 18 | 19 | logger = logging.getLogger("dinov2") 20 | 21 | 22 | def apply_scaling_rules_to_cfg(cfg): # to fix 23 | if cfg.optim.scaling_rule == "sqrt_wrt_1024": 24 | base_lr = cfg.optim.base_lr 25 | cfg.optim.lr = base_lr 26 | cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) 27 | logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") 28 | else: 29 | raise NotImplementedError 30 | return cfg 31 | 32 | 33 | def write_config(cfg, output_dir, name="config.yaml"): 34 | logger.info(OmegaConf.to_yaml(cfg)) 35 | saved_cfg_path = os.path.join(output_dir, name) 36 | with open(saved_cfg_path, "w") as f: 37 | OmegaConf.save(config=cfg, f=f) 38 | return saved_cfg_path 39 | 40 | 41 | def get_cfg_from_args(args): 42 | args.output_dir = os.path.abspath(args.output_dir) 43 | args.opts += [f"train.output_dir={args.output_dir}"] 44 | default_cfg = OmegaConf.create(dinov2_default_config) 45 | cfg = OmegaConf.load(args.config_file) 46 | cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) 47 | return cfg 48 | 49 | 50 | def default_setup(args): 51 | distributed.enable(overwrite=True) 52 | seed = getattr(args, "seed", 0) 53 | rank = distributed.get_global_rank() 54 | 55 | global logger 56 | setup_logging(output=args.output_dir, level=logging.INFO) 57 | logger = logging.getLogger("dinov2") 58 | 59 | utils.fix_random_seeds(seed + rank) 60 | logger.info("git:\n {}\n".format(utils.get_sha())) 61 | logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 62 | 63 | 64 | def setup(args): 65 | """ 66 | Create configs and perform basic setups. 
67 | """ 68 | cfg = get_cfg_from_args(args) 69 | os.makedirs(args.output_dir, exist_ok=True) 70 | default_setup(args) 71 | apply_scaling_rules_to_cfg(cfg) 72 | write_config(cfg, args.output_dir) 73 | return cfg 74 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/dtype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from typing import Dict, Union 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | TypeSpec = Union[str, np.dtype, torch.dtype] 15 | 16 | 17 | _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { 18 | np.dtype("bool"): torch.bool, 19 | np.dtype("uint8"): torch.uint8, 20 | np.dtype("int8"): torch.int8, 21 | np.dtype("int16"): torch.int16, 22 | np.dtype("int32"): torch.int32, 23 | np.dtype("int64"): torch.int64, 24 | np.dtype("float16"): torch.float16, 25 | np.dtype("float32"): torch.float32, 26 | np.dtype("float64"): torch.float64, 27 | np.dtype("complex64"): torch.complex64, 28 | np.dtype("complex128"): torch.complex128, 29 | } 30 | 31 | 32 | def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: 33 | if isinstance(dtype, torch.dtype): 34 | return dtype 35 | if isinstance(dtype, str): 36 | dtype = np.dtype(dtype) 37 | assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" 38 | return _NUMPY_TO_TORCH_DTYPE[dtype] 39 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/param_groups.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | import logging 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): 15 | """ 16 | Calculate lr decay rate for different ViT blocks. 17 | Args: 18 | name (string): parameter name. 19 | lr_decay_rate (float): base lr decay rate. 20 | num_layers (int): number of ViT blocks. 21 | Returns: 22 | lr decay rate for the given parameter. 23 | """ 24 | layer_id = num_layers + 1 25 | if name.startswith("backbone") or force_is_backbone: 26 | if ".pos_embed" in name or ".patch_embed" in name or ".mask_token" in name or ".cls_token" in name: 27 | layer_id = 0 28 | elif force_is_backbone and ( 29 | "pos_embed" in name or "patch_embed" in name or "mask_token" in name or "cls_token" in name 30 | ): 31 | layer_id = 0 32 | elif ".blocks." in name and ".residual." not in name: 33 | layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 34 | elif chunked_blocks and "blocks." in name and "residual." not in name: 35 | layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 36 | elif "blocks." in name and "residual." 
not in name: 37 | layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 38 | 39 | return lr_decay_rate ** (num_layers + 1 - layer_id) 40 | 41 | 42 | def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): 43 | chunked_blocks = False 44 | if hasattr(model, "n_blocks"): 45 | logger.info("chunked fsdp") 46 | n_blocks = model.n_blocks 47 | chunked_blocks = model.chunked_blocks 48 | elif hasattr(model, "blocks"): 49 | logger.info("first code branch") 50 | n_blocks = len(model.blocks) 51 | elif hasattr(model, "backbone"): 52 | logger.info("second code branch") 53 | n_blocks = len(model.backbone.blocks) 54 | else: 55 | logger.info("else code branch") 56 | n_blocks = 0 57 | all_param_groups = [] 58 | 59 | for name, param in model.named_parameters(): 60 | name = name.replace("_fsdp_wrapped_module.", "") 61 | if not param.requires_grad: 62 | continue 63 | decay_rate = get_vit_lr_decay_rate( 64 | name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks 65 | ) 66 | d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} 67 | 68 | if "last_layer" in name: 69 | d.update({"is_last_layer": True}) 70 | 71 | if name.endswith(".bias") or "norm" in name or "gamma" in name: 72 | d.update({"wd_multiplier": 0.0}) 73 | 74 | if "patch_embed" in name: 75 | d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) 76 | 77 | all_param_groups.append(d) 78 | logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") 79 | 80 | return all_param_groups 81 | 82 | 83 | def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): 84 | fused_params_groups = defaultdict(lambda: {"params": []}) 85 | for d in all_params_groups: 86 | identifier = "" 87 | for k in keys: 88 | identifier += k + str(d[k]) + "_" 89 | 90 | for k in keys: 91 | fused_params_groups[identifier][k] = d[k] 92 | fused_params_groups[identifier]["params"].append(d["params"]) 93 | 94 | return fused_params_groups.values() 95 | -------------------------------------------------------------------------------- /dinov2/dinov2/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
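Sketch of how the helpers in dinov2/utils/param_groups.py above feed an optimizer. The model and hyperparameter values are assumptions; as in the DINOv2 trainer, the stored lr/wd multipliers are applied per group inside the training loop.

import torch

# `model` is assumed to be a ViT-style nn.Module whose parameter names match the rules above.
param_groups = get_params_groups_with_decay(model, lr_decay_rate=0.9, patch_embed_lr_mult=0.2)
optimizer = torch.optim.AdamW(list(fuse_params_groups(param_groups)), lr=1e-3, weight_decay=0.04)

base_lr, base_wd = 1e-3, 0.04
for group in optimizer.param_groups:
    group["lr"] = base_lr * group["lr_multiplier"]
    group["weight_decay"] = base_wd * group["wd_multiplier"]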
6 | 7 | import logging 8 | import os 9 | import random 10 | import subprocess 11 | from urllib.parse import urlparse 12 | 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | def load_pretrained_weights(model, pretrained_weights, checkpoint_key): 22 | if urlparse(pretrained_weights).scheme: # If it looks like an URL 23 | state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") 24 | else: 25 | state_dict = torch.load(pretrained_weights, map_location="cpu") 26 | if checkpoint_key is not None and checkpoint_key in state_dict: 27 | logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") 28 | state_dict = state_dict[checkpoint_key] 29 | # remove `module.` prefix 30 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 31 | # remove `backbone.` prefix induced by multicrop wrapper 32 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 33 | msg = model.load_state_dict(state_dict, strict=False) 34 | logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) 35 | 36 | 37 | def fix_random_seeds(seed=31): 38 | """ 39 | Fix random seeds. 40 | """ 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 46 | 47 | def get_sha(): 48 | cwd = os.path.dirname(os.path.abspath(__file__)) 49 | 50 | def _run(command): 51 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 52 | 53 | sha = "N/A" 54 | diff = "clean" 55 | branch = "N/A" 56 | try: 57 | sha = _run(["git", "rev-parse", "HEAD"]) 58 | subprocess.check_output(["git", "diff"], cwd=cwd) 59 | diff = _run(["git", "diff-index", "HEAD"]) 60 | diff = "has uncommitted changes" if diff else "clean" 61 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 62 | except Exception: 63 | pass 64 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 65 | return message 66 | 67 | 68 | class CosineScheduler(object): 69 | def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): 70 | super().__init__() 71 | self.final_value = final_value 72 | self.total_iters = total_iters 73 | 74 | freeze_schedule = np.zeros((freeze_iters)) 75 | 76 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 77 | 78 | iters = np.arange(total_iters - warmup_iters - freeze_iters) 79 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 80 | self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) 81 | 82 | assert len(self.schedule) == self.total_iters 83 | 84 | def __getitem__(self, it): 85 | if it >= self.total_iters: 86 | return self.final_value 87 | else: 88 | return self.schedule[it] 89 | 90 | 91 | def has_batchnorms(model): 92 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 93 | for name, module in model.named_modules(): 94 | if isinstance(module, bn_types): 95 | return True 96 | return False 97 | -------------------------------------------------------------------------------- /dinov2/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | 4 | [tool.pylint.master] 5 | persistent = false 6 | score = false 7 | 8 | [tool.pylint.messages_control] 9 | disable = "all" 10 | enable = [ 11 | "miscellaneous", 12 | "similarities", 13 
| ] 14 | 15 | [tool.pylint.similarities] 16 | ignore-comments = true 17 | ignore-docstrings = true 18 | ignore-imports = true 19 | min-similarity-lines = 8 20 | 21 | [tool.pylint.reports] 22 | reports = false 23 | 24 | [tool.pylint.miscellaneous] 25 | notes = [ 26 | "FIXME", 27 | "XXX", 28 | "TODO", 29 | ] 30 | -------------------------------------------------------------------------------- /dinov2/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black==22.6.0 2 | flake8==5.0.4 3 | pylint==2.15.0 4 | -------------------------------------------------------------------------------- /dinov2/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==2.0.0 3 | torchvision==0.15.0 4 | omegaconf 5 | torchmetrics==0.10.3 6 | fvcore 7 | iopath 8 | xformers==0.0.18 9 | submitit 10 | --extra-index-url https://pypi.nvidia.com 11 | cuml-cu11 12 | -------------------------------------------------------------------------------- /dinov2/scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ -n "$1" ]; then 4 | echo "linting \"$1\"" 5 | fi 6 | 7 | echo "running black" 8 | if [ -n "$1" ]; then 9 | black "$1" 10 | else 11 | black dinov2 12 | fi 13 | 14 | echo "running flake8" 15 | if [ -n "$1" ]; then 16 | flake8 "$1" 17 | else 18 | flake8 19 | fi 20 | 21 | echo "running pylint" 22 | if [ -n "$1" ]; then 23 | pylint "$1" 24 | else 25 | pylint dinov2 26 | fi 27 | 28 | exit 0 29 | -------------------------------------------------------------------------------- /dinov2/setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203,E501,W503 4 | per-file-ignores = 5 | __init__.py:F401 6 | -------------------------------------------------------------------------------- /dinov2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | import re 9 | from typing import List, Tuple 10 | 11 | from setuptools import setup, find_packages 12 | 13 | 14 | NAME = "dinov2" 15 | DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method." 
16 | 17 | URL = "https://github.com/facebookresearch/dinov2" 18 | AUTHOR = "FAIR" 19 | REQUIRES_PYTHON = ">=3.9.0" 20 | HERE = Path(__file__).parent 21 | 22 | 23 | try: 24 | with open(HERE / "README.md", encoding="utf-8") as f: 25 | long_description = "\n" + f.read() 26 | except FileNotFoundError: 27 | long_description = DESCRIPTION 28 | 29 | 30 | def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]: 31 | requirements = [] 32 | extra_indices = [] 33 | with open(path) as f: 34 | for line in f.readlines(): 35 | line = line.rstrip("\r\n") 36 | if line.startswith("--extra-index-url "): 37 | extra_indices.append(line[18:]) 38 | continue 39 | requirements.append(line) 40 | return requirements, extra_indices 41 | 42 | 43 | def get_package_version() -> str: 44 | with open(HERE / "dinov2/__init__.py") as f: 45 | result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) 46 | if result: 47 | return result.group(1) 48 | raise RuntimeError("Can't get package version") 49 | 50 | 51 | requirements, extra_indices = get_requirements() 52 | version = get_package_version() 53 | dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt") 54 | 55 | 56 | setup( 57 | name=NAME, 58 | version=version, 59 | description=DESCRIPTION, 60 | long_description=long_description, 61 | long_description_content_type="text/markdown", 62 | author=AUTHOR, 63 | python_requires=REQUIRES_PYTHON, 64 | url=URL, 65 | packages=find_packages(), 66 | package_data={ 67 | "": ["*.yaml"], 68 | }, 69 | install_requires=requirements, 70 | dependency_links=extra_indices, 71 | extras_require={ 72 | "dev": dev_requirements, 73 | }, 74 | install_package_data=True, 75 | license="CC-BY-NC", 76 | license_files=("LICENSE",), 77 | classifiers=[ 78 | # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py 79 | "Development Status :: 3 - Alpha", 80 | "Intended Audience :: Developers", 81 | "Intended Audience :: Science/Research", 82 | "License :: Other/Proprietary License", 83 | "Programming Language :: Python :: 3.9", 84 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 85 | "Topic :: Software Development :: Libraries :: Python Modules", 86 | ], 87 | ) 88 | -------------------------------------------------------------------------------- /dinov2/vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import torchvision.transforms as T 4 | import hubconf 5 | from sklearn.decomposition import PCA 6 | from PIL import Image 7 | import numpy as np 8 | 9 | def visualize(features, model): 10 | # vis code from https://github.com/dichotomies/N3F/tree/master/feature_extractor 11 | h, w = int(img.shape[2] / model.patch_embed.patch_size[0]), int( 12 | img.shape[3] / model.patch_embed.patch_size[1] 13 | ) 14 | dim = features.shape[-1] 15 | features = features.reshape(-1, h, w, dim).permute(0, 3, 1, 2) 16 | print(features.shape) 17 | 18 | all_features = features.cpu() 19 | pca = PCA(n_components=3) 20 | N, C, H, W = all_features.shape 21 | all_features = all_features.permute(0, 2, 3, 1).view(-1, C).numpy() 22 | print("Features shape: ", all_features.shape) 23 | pca_features = pca.fit_transform(all_features) 24 | pca_features = (pca_features - pca_features.min()) / (pca_features.max() - pca_features.min()) 25 | pca_features = pca_features * 255 26 | pca_features = pca_features.reshape(h, w, 3) 27 | print(pca_features.shape) 28 | vis_img = 
Image.fromarray(pca_features.astype(np.uint8)) 29 | vis_img.save('vis.jpg') 30 | 31 | 32 | model = hubconf.dinov2_vitl14() 33 | 34 | img = Image.open('00012_00.jpg') 35 | transform = T.Compose([ 36 | # T.Resize(256, interpolation=T.InterpolationMode.BICUBIC), 37 | # T.CenterCrop(224), 38 | # T.Resize((574,448), interpolation=T.InterpolationMode.BICUBIC), 39 | T.Resize((280,224), interpolation=T.InterpolationMode.BICUBIC), 40 | T.ToTensor(), 41 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 42 | ]) 43 | 44 | img = transform(img)[:3].unsqueeze(0) 45 | img = img.cuda() 46 | model = model.cuda() 47 | with torch.no_grad(): 48 | # Pass the image tensor to the DINOv2 model to extract features 49 | features_dict = model.forward_features(img) 50 | features_patchtokens = features_dict['x_norm_patchtokens'] 51 | feature_clstoken = features_dict['x_norm_clstoken'].unsqueeze(1) 52 | print(feature_clstoken.shape,features_patchtokens.shape) 53 | features = torch.cat([feature_clstoken,features_patchtokens],1) 54 | print(features.shape) 55 | 56 | visualize(features_patchtokens, model) -------------------------------------------------------------------------------- /figures/Figure12_TikTok_case1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/figures/Figure12_TikTok_case1.mp4 -------------------------------------------------------------------------------- /figures/Figure12_TikTok_case5.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/figures/Figure12_TikTok_case5.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/figures/Figure12_TikTok_case5.mp4 -------------------------------------------------------------------------------- /figures/case1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/figures/case1.gif -------------------------------------------------------------------------------- /figures/case2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/figures/case2.gif -------------------------------------------------------------------------------- /figures/p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/figures/p1.png -------------------------------------------------------------------------------- /figures/p2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/figures/p2.png -------------------------------------------------------------------------------- /mae/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Requirements 4 | - Python >= 3.8 5 | - NumPy 6 | - PyTorch >= 1.3 7 | - [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. 8 | You can install them together at [pytorch.org](https://pytorch.org) to make sure of this.
9 | - simplejson: `pip install simplejson` 10 | - GCC >= 4.9 11 | - PyAV: `conda install av -c conda-forge` 12 | - ffmpeg (4.0 is preferred, will be installed along with PyAV) 13 | - tqdm: `pip install tqdm` 14 | - iopath: `pip install -U iopath` or `conda install -c iopath iopath` 15 | - psutil: `pip install psutil` 16 | - OpenCV: `pip install opencv-python` 17 | - torchvision: `pip install torchvision` or `conda install torchvision -c pytorch` 18 | - tensorboard: `pip install tensorboard` 19 | - timm: `pip install timm==0.3.2` 20 | 21 | ## PyTorch 22 | Please follow the official PyTorch instructions to install from source: 23 | ``` 24 | git clone --recursive https://github.com/pytorch/pytorch 25 | ``` 26 | -------------------------------------------------------------------------------- /mae/README.md: -------------------------------------------------------------------------------- 1 | ## Masked Autoencoders As Spatiotemporal Learners: A PyTorch Implementation 2 | 3 |

4 | 5 |

6 | 7 | 8 | This is a PyTorch/GPU re-implementation of the paper [Masked Autoencoders As Spatiotemporal Learners](https://arxiv.org/abs/2205.09113): 9 | ``` 10 | @Article{MaskedAutoencodersSpatiotemporal2022, 11 | author = {Christoph Feichtenhofer and Haoqi Fan and Yanghao Li and Kaiming He}, 12 | journal = {arXiv:2205.09113}, 13 | title = {Masked Autoencoders As Spatiotemporal Learners}, 14 | year = {2022}, 15 | } 16 | ``` 17 | 18 | * This repo is a modification on the [MAE repo](https://github.com/facebookresearch/mae). Installation and preparation follow [INSTALL.md](INSTALL.md). 19 | 20 | * This repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. 21 | 22 | 23 | 24 | ### Catalog 25 | 26 | - [x] Visualization demo 27 | - [x] Pre-trained checkpoints + fine-tuning code + testing code 28 | - [x] Pre-training code 29 | 30 | ### Visualization demo 31 | 32 | 33 | Visualization of MAE output with 95% (left) and 98% (right) mask rate on the same video. 34 |
35 | 36 |
37 | 38 | 39 | Visualization of MAE output with 95% (left) and 98% (right) mask rate on the same video. 40 |
41 | 42 |
43 | 44 | Visualization of MAE output with 95% mask rate. 45 |
46 | 47 |
48 | 49 | Run our interactive visualization demo using [Colab notebook](https://colab.research.google.com/github/haooooooqi/visualization/blob/main/video_mae_visualize.ipynb) (no GPU needed): 50 |

51 | 52 |

53 | 54 | ### Fine-tuning with pre-trained checkpoints 55 | 56 | The following tables provide the pre-trained checkpoints used in the paper, pretrained with **90%** mask ratio and **1600 effective epochs**, converted from PySlowFast codebase:

|  | ViT-Large | ViT-Huge |
| --- | --- | --- |
| pre-trained checkpoint on Kinetics-400 | download | download |
| md5 | edf3a5 | 3d7f64 |

|  | ViT-Large | ViT-Huge |
| --- | --- | --- |
| pre-trained checkpoint on Kinetics-600 | download | download |
| md5 | 9a9645 | 27495e |

|  | ViT-Large | ViT-Huge |
| --- | --- | --- |
| pre-trained checkpoint on Kinetics-700 | download | download |
| md5 | cdbada | 4c4e3c |
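The md5 values above are short prefixes for verifying the downloaded files. Below is a minimal sketch (not part of this repo) of checking a downloaded checkpoint against its listed prefix and inspecting its weights; the local file name `mae_pretrain_vit_large_k400.pth` and the `model_state`/`model` dictionary keys are assumptions about how the converted PySlowFast checkpoints are laid out, so adjust them to the actual download:

```
import hashlib

import torch


def md5_prefix(path, length=6, chunk_size=1 << 20):
    # Hash the file in chunks so large checkpoints do not need to fit in memory.
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()[:length]


ckpt_path = "mae_pretrain_vit_large_k400.pth"  # assumed local file name
assert md5_prefix(ckpt_path) == "edf3a5", "md5 prefix does not match the ViT-Large Kinetics-400 entry"

checkpoint = torch.load(ckpt_path, map_location="cpu")
# Converted checkpoints often nest the weights under a key such as "model_state"
# or "model"; fall back to the raw dict if neither key is present.
state_dict = checkpoint.get("model_state", checkpoint.get("model", checkpoint))
print(f"{len(state_dict)} tensors in the checkpoint")
```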
113 | 114 | 115 | The fine-tuning instruction is in [FINETUNE.md](FINETUNE.md). 116 | 117 | 118 | ### Pre-training 119 | 120 | The pre-training instruction is in [PRETRAIN.md](PRETRAIN.md). 121 | 122 | ### License 123 | 124 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 125 | -------------------------------------------------------------------------------- /mae/TARGETS: -------------------------------------------------------------------------------- 1 | load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") 2 | load("@fbcode_macros//build_defs:python_library.bzl", "python_library") 3 | load("//tools/xar:defs.bzl", "xar_python_binary") 4 | 5 | python_library( 6 | name = "lib", 7 | srcs = glob( 8 | [ 9 | "*.py", 10 | "**/*.py", 11 | "**/**/*.py", 12 | ], 13 | exclude = [ 14 | "main_finetune.py", 15 | "main_linprobe.py", 16 | "main_pretrain.py", 17 | ], 18 | ), 19 | base_module = "mae", 20 | deps = [ 21 | "fbsource//third-party/pypi/timm:timm", 22 | "//caffe2:torch", 23 | "//caffe2:torch_tensorboard", 24 | "//caffe2/torch/fb/rendezvous:zeus", 25 | "//fair_infra/data/iopath/iopath:iopath", 26 | "//fair_infra/data/prefetcher:prefetcher", 27 | "//fblearner/flow/facebook:flow_fb_lib", 28 | "//github/facebookresearch/fairscale:fairscale", 29 | "//python/wheel/av:av", 30 | "//python/wheel/moviepy:moviepy", 31 | "//ti/urlgen:everstore_url_py", 32 | "//vision/fair/detectron2/detectron2:detectron2", 33 | "//vision/fair/fvcore/fvcore:fvcore", 34 | "//vision/fair/pytorchvideo/pytorchvideo:pytorchvideo", 35 | ], 36 | external_deps = [ 37 | "PyYAML", 38 | "matplotlib", 39 | "numpy", 40 | "opencv3", 41 | "simplejson", 42 | ("pycurl", None), 43 | "scikit-learn", 44 | ], 45 | ) 46 | 47 | python_library( 48 | name = "main_finetune", 49 | srcs = ["main_finetune.py"], 50 | base_module = "", 51 | py_version = ">=3.6", 52 | deps = ["//vision/fair/mae:lib"], 53 | ) 54 | 55 | python_binary( 56 | name = "run_finetune_bin", 57 | srcs = ["run_finetune.py"], 58 | base_module = "", 59 | compile = "with-source", 60 | main_module = "run_finetune", 61 | par_style = "xar", 62 | py_version = ">=3.6", 63 | deps = [ 64 | "//vision/fair/mae:lib", 65 | "//vision/fair/mae:main_finetune", 66 | ], 67 | ) 68 | 69 | xar_python_binary( 70 | name = "run_finetune_xar", 71 | output_name = "run_finetune.xar", 72 | src_rule_name = ":run_finetune_bin", 73 | ) 74 | 75 | python_library( 76 | name = "main_pretrain", 77 | srcs = ["main_pretrain.py"], 78 | base_module = "", 79 | py_version = ">=3.6", 80 | deps = ["//vision/fair/mae:lib"], 81 | ) 82 | 83 | python_binary( 84 | name = "run_pretrain_bin", 85 | srcs = ["run_pretrain.py"], 86 | base_module = "", 87 | compile = "with-source", 88 | main_module = "run_pretrain", 89 | par_style = "xar", 90 | py_version = ">=3.6", 91 | deps = [ 92 | "//vision/fair/mae:lib", 93 | "//vision/fair/mae:main_pretrain", 94 | ], 95 | ) 96 | 97 | xar_python_binary( 98 | name = "run_pretrain_xar", 99 | output_name = "run_pretrain.xar", 100 | src_rule_name = ":run_pretrain_bin", 101 | ) 102 | 103 | python_library( 104 | name = "main_test", 105 | srcs = ["main_test.py"], 106 | base_module = "", 107 | py_version = ">=3.6", 108 | deps = ["//vision/fair/mae:lib"], 109 | ) 110 | 111 | python_binary( 112 | name = "run_test_bin", 113 | srcs = ["run_test.py"], 114 | base_module = "", 115 | compile = "with-source", 116 | main_module = "run_test", 117 | par_style = "xar", 118 | py_version = ">=3.6", 119 | deps = [ 120 | "//vision/fair/mae:lib", 121 | 
"//vision/fair/mae:main_test", 122 | ], 123 | ) 124 | 125 | xar_python_binary( 126 | name = "run_test_xar", 127 | output_name = "run_test.xar", 128 | src_rule_name = ":run_test_bin", 129 | ) 130 | -------------------------------------------------------------------------------- /mae/TODO.md: -------------------------------------------------------------------------------- 1 | TODO: 2 | 3 | * finish 800 and 1600 ep sanity check and update it to the [benchmark sheet](https://docs.google.com/spreadsheets/d/1XiXG_227fu8GFD9tXvBBldj41d6K-87JNUKz7EZqB44/edit?usp=sharing) 4 | * Polish the comments/ instrcutions for the interactivate demo, README 5 | * Code review with the team 6 | * SRT for repo 7 | -------------------------------------------------------------------------------- /mae/__pycache__/models_mae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/__pycache__/models_mae.cpython-310.pyc -------------------------------------------------------------------------------- /mae/__pycache__/models_mae.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/__pycache__/models_mae.cpython-39.pyc -------------------------------------------------------------------------------- /mae/demo-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/demo-preview.png -------------------------------------------------------------------------------- /mae/demo/goods.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/demo/goods.mp4 -------------------------------------------------------------------------------- /mae/demo/qZ_lFjCiR1c_000104_000114.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/demo/qZ_lFjCiR1c_000104_000114.avi -------------------------------------------------------------------------------- /mae/demo/v_KW4TDvxIc_000223_000233.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/demo/v_KW4TDvxIc_000223_000233.mp4 -------------------------------------------------------------------------------- /mae/engine_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import mae.util.misc as misc 13 | import torch 14 | 15 | 16 | @torch.no_grad() 17 | def test(data_loader, model, device, test_meter, fp32=False): 18 | metric_logger = misc.MetricLogger(delimiter=" ") 19 | 20 | # switch to evaluation mode 21 | model.eval() 22 | 23 | for cur_iter, (images, labels, video_idx) in enumerate(data_loader): 24 | images = images.to(device, non_blocking=True) 25 | labels = labels.to(device, non_blocking=True) 26 | video_idx = video_idx.to(device, non_blocking=True) 27 | 28 | if len(images.shape) == 6: 29 | b, r, c, t, h, w = images.shape 30 | images = images.view(b * r, c, t, h, w) 31 | labels = labels.view(b * r) 32 | 33 | # compute output 34 | with torch.cuda.amp.autocast(enabled=not fp32): 35 | preds = model(images) 36 | 37 | if torch.distributed.is_initialized(): 38 | preds, labels, video_idx = misc.all_gather([preds, labels, video_idx]) 39 | preds = preds.cpu() 40 | labels = labels.cpu() 41 | video_idx = video_idx.cpu() 42 | # Update and log stats. 43 | test_meter.update_stats(preds.detach(), labels.detach(), video_idx.detach()) 44 | test_meter.log_iter_stats(cur_iter) 45 | 46 | test_meter.finalize_metrics() 47 | # gather the stats from all processes 48 | metric_logger.synchronize_between_processes() 49 | return test_meter.stats 50 | -------------------------------------------------------------------------------- /mae/launch_local.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # sudo fuser -v /dev/nvidia* | grep -o '[[:digit:]]*' |xargs -I{} sudo kill -9 {} 4 | 5 | # buck build --config client.skip-action-graph-cache=true @mode/opt -c python.native_link_strategy=separate \ 6 | buck build @mode/opt @mode/inplace \ 7 | //vision/fair/mae/... 
--show-output 8 | 9 | # 0: pretrain, 1: finetune, 2: test 10 | 11 | if [ "$1" -lt 1 ] 12 | then 13 | 14 | echo "pretrain" 15 | 16 | /data/users/"${USER}"/fbsource/fbcode/buck-out/gen/vision/fair/mae/run_pretrain_bin.par \ 17 | --decoder_attn AttentionOrg --encoder_attn AttentionOrg \ 18 | --decoder_num_heads 16 \ 19 | --batch_size 2 --decoder_embed_dim 512 --decoder_depth 4 \ 20 | --epochs 400 --mask_ratio 0.90 --repeat_aug 2 --video \ 21 | --model mae_vit_large_patch16 \ 22 | --sampling_rate 4 --num_frames 16 \ 23 | --mask_type st \ 24 | --num_workers 2 \ 25 | --bias_wd \ 26 | --mask_schedule "cos" --mask_ratio 0.90 --mask_ratio_end 0.99 \ 27 | --trunc_init \ 28 | --fp32 \ 29 | --knn_monitor --knn_period 1 \ 30 | --t_patch_size 2 \ 31 | --fp32 \ 32 | --mask_type st --mask_schedule const --mask_ratio 0.95 --repeat_aug 4 --sampling_rate 4 --norm_pix_loss \ 33 | --resume manifold://winvision/tree/haoqifan/logs/2022-04-22-212409-469/pretrain/checkpoint-00064.pth \ 34 | 35 | --sep_pos_embed \ 36 | --learnable_pos_embed \ 37 | 38 | --decoder_attn AttentionRelPos --encoder_attn AttentionRelPos --rel_pos_embed \ 39 | 40 | else 41 | 42 | if [ "$1" -lt 2 ] 43 | then 44 | 45 | echo "finetune" 46 | 47 | # AttentionSubsampleMaxpool, AttentionSubsampleStride2, AttentionSubsampleRand10, AttentionSubsampleRand25, AttentionSubsampleRand50, 48 | /data/users/"${USER}"/fbsource/fbcode/buck-out/gen/vision/fair/mae/run_finetune_bin.par \ 49 | --batch_size 1 --epochs 1 --repeat_aug 1 --video --smoothing 0.1 \ 50 | --mixup 0.0 --cutmix 0.0 --mixup_prob 0.0 \ 51 | --model vit_large_patch16 \ 52 | --t_patch_size 4 --num_frames 16 \ 53 | --rand_aug \ 54 | --encoder_attn AttentionRelPos \ 55 | --rel_pos_init_std 1.0 \ 56 | --sep_pos_embed \ 57 | --fp32 \ 58 | 59 | --finetune manifold://winvision/tree/haoqifan/logs/2022-02-05-204420-480/pretrain/checkpoint-00399.pth \ 60 | 61 | # --no_qkv_bias 62 | 63 | # --encoder_attn AttentionSubsampleRand10 \ 64 | 65 | # --finetune manifold://fair_logging/tree/haoqifan/logs/2022-01-17-162701-592/pretrain/checkpoint-399.pth 66 | 67 | else 68 | 69 | echo "test" 70 | 71 | # AttentionSubsampleMaxpool, AttentionSubsampleStride2, AttentionSubsampleRand10, AttentionSubsampleRand25, AttentionSubsampleRand50, 72 | /data/users/"${USER}"/fbsource/fbcode/buck-out/gen/vision/fair/mae/run_test_bin.par \ 73 | --batch_size 2 --encoder_attn AttentionSubsampleRand10 \ 74 | --model vit_large_patch16 \ 75 | --t_patch_size 4 --num_frames 16 \ 76 | --finetune manifold://fair_logging/tree/haoqifan/logs/2022-01-25-012936-625/downstream/checkpoint-99.pth 77 | 78 | # --finetune manifold://fair_logging/tree/haoqifan/logs/2022-01-17-162701-592/pretrain/checkpoint-399.pth 79 | 80 | fi 81 | fi 82 | -------------------------------------------------------------------------------- /mae/launch_tensorboard.sh: -------------------------------------------------------------------------------- 1 | # buck build @mode/opt //tensorboard 2 | if [ "$1" -lt 1 ] 3 | then 4 | ~/local/fbsource/fbcode/buck-out/gen/tensorboard/tensorboard.par --port=8092 --logdir=manifold://winvision/tree/haoqifan/logs/tensorboard/pretrain 5 | else 6 | ~/local/fbsource/fbcode/buck-out/gen/tensorboard/tensorboard.par --port=8091 --logdir=manifold://winvision/tree/haoqifan/logs/tensorboard/downstream 7 | fi 8 | 9 | # ~/local/fbsource/fbcode/buck-out/gen/tensorboard/tensorboard.par --port=8095 
--logdir_spec=fair_logging:manifold://fair_logging/tree/haoqifan/logs/tensorboard/downstream,winvision:manifold://winvision/tree/haoqifan/logs/tensorboard/downstream 10 | -------------------------------------------------------------------------------- /mae/masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/masks.png -------------------------------------------------------------------------------- /mae/mr-95-demo-vid-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/mr-95-demo-vid-0.gif -------------------------------------------------------------------------------- /mae/mr-95-demo-vid-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/mr-95-demo-vid-1.gif -------------------------------------------------------------------------------- /mae/mr-98-demo-vid-0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/mr-98-demo-vid-0.gif -------------------------------------------------------------------------------- /mae/mr-98-demo-vid-1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/mr-98-demo-vid-1.gif -------------------------------------------------------------------------------- /mae/rm_files/submitit_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_finetune as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE finetune", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. 
Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_finetune as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /mae/rm_files/submitit_linprobe.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_linprobe as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE linear probe", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 
44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_linprobe as classification 57 | 58 | self._setup_gpu_args() 59 | classification.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /mae/rm_files/submitit_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # A script to run multinode training with submitit. 
8 | # -------------------------------------------------------- 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main_pretrain as trainer 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | trainer_parser = trainer.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for MAE pretrain", parents=[trainer_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | 27 | parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit") 28 | parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") 29 | parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") 30 | return parser.parse_args() 31 | 32 | 33 | def get_shared_folder() -> Path: 34 | user = os.getenv("USER") 35 | if Path("/checkpoint/").is_dir(): 36 | p = Path(f"/checkpoint/{user}/experiments") 37 | p.mkdir(exist_ok=True) 38 | return p 39 | raise RuntimeError("No shared folder available") 40 | 41 | 42 | def get_init_file(): 43 | # Init file must not exist, but it's parent dir must exist. 44 | os.makedirs(str(get_shared_folder()), exist_ok=True) 45 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 46 | if init_file.exists(): 47 | os.remove(str(init_file)) 48 | return init_file 49 | 50 | 51 | class Trainer(object): 52 | def __init__(self, args): 53 | self.args = args 54 | 55 | def __call__(self): 56 | import main_pretrain as trainer 57 | 58 | self._setup_gpu_args() 59 | trainer.main(self.args) 60 | 61 | def checkpoint(self): 62 | import os 63 | import submitit 64 | 65 | self.args.dist_url = get_init_file().as_uri() 66 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 67 | if os.path.exists(checkpoint_file): 68 | self.args.resume = checkpoint_file 69 | print("Requeuing ", self.args) 70 | empty_trainer = type(self)(self.args) 71 | return submitit.helpers.DelayedSubmission(empty_trainer) 72 | 73 | def _setup_gpu_args(self): 74 | import submitit 75 | from pathlib import Path 76 | 77 | job_env = submitit.JobEnvironment() 78 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 79 | self.args.log_dir = self.args.output_dir 80 | self.args.gpu = job_env.local_rank 81 | self.args.rank = job_env.global_rank 82 | self.args.world_size = job_env.num_tasks 83 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | if args.job_dir == "": 89 | args.job_dir = get_shared_folder() / "%j" 90 | 91 | # Note that the folder will depend on the job_id, to easily track experiments 92 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 93 | 94 | num_gpus_per_node = args.ngpus 95 | nodes = args.nodes 96 | timeout_min = args.timeout 97 | 98 | partition = args.partition 99 | kwargs = {} 100 | if args.use_volta32: 101 | kwargs['slurm_constraint'] = 'volta32gb' 102 | if args.comment: 103 | kwargs['slurm_comment'] = args.comment 104 | 105 | executor.update_parameters( 106 | mem_gb=40 * num_gpus_per_node, 107 | gpus_per_node=num_gpus_per_node, 108 | 
tasks_per_node=num_gpus_per_node, # one task per GPU 109 | cpus_per_task=10, 110 | nodes=nodes, 111 | timeout_min=timeout_min, # max is 60 * 72 112 | # Below are cluster dependent parameters 113 | slurm_partition=partition, 114 | slurm_signal_delay_s=120, 115 | **kwargs 116 | ) 117 | 118 | executor.update_parameters(name="mae") 119 | 120 | args.dist_url = get_init_file().as_uri() 121 | args.output_dir = args.job_dir 122 | 123 | trainer = Trainer(args) 124 | job = executor.submit(trainer) 125 | 126 | # print("Submitted job_id:", job.job_id) 127 | print(job.job_id) 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /mae/run_finetune.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from main_finetune import get_args_parser, main 4 | 5 | 6 | if __name__ == "__main__": 7 | args = get_args_parser() 8 | args = args.parse_args() 9 | if args.output_dir: 10 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 11 | main(args) 12 | -------------------------------------------------------------------------------- /mae/run_pretrain.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from main_pretrain import get_args_parser, main 3 | 4 | 5 | if __name__ == '__main__': 6 | args = get_args_parser() 7 | args = args.parse_args() 8 | if args.output_dir: 9 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 10 | main(args) 11 | -------------------------------------------------------------------------------- /mae/run_test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from main_test import get_args_parser, main 4 | 5 | 6 | if __name__ == "__main__": 7 | args = get_args_parser() 8 | args = args.parse_args() 9 | if args.output_dir: 10 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 11 | main(args) 12 | -------------------------------------------------------------------------------- /mae/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/teaser.png -------------------------------------------------------------------------------- /mae/util/.misc.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/.misc.py.swp -------------------------------------------------------------------------------- /mae/util/__pycache__/logging.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/__pycache__/logging.cpython-310.pyc -------------------------------------------------------------------------------- /mae/util/__pycache__/logging.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/__pycache__/logging.cpython-39.pyc -------------------------------------------------------------------------------- /mae/util/__pycache__/pos_embed.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/__pycache__/pos_embed.cpython-310.pyc -------------------------------------------------------------------------------- /mae/util/__pycache__/pos_embed.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/__pycache__/pos_embed.cpython-39.pyc -------------------------------------------------------------------------------- /mae/util/__pycache__/video_vit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/__pycache__/video_vit.cpython-310.pyc -------------------------------------------------------------------------------- /mae/util/__pycache__/video_vit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/__pycache__/video_vit.cpython-39.pyc -------------------------------------------------------------------------------- /mae/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /mae/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/rand_augment.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/rand_augment.cpython-310.pyc -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/rand_augment.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/rand_augment.cpython-39.pyc -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/random_erasing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/random_erasing.cpython-310.pyc -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/random_erasing.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/random_erasing.cpython-39.pyc -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/transform.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/transform.cpython-310.pyc -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/transform.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/transform.cpython-39.pyc -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /mae/util/decoder/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/mae/util/decoder/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /mae/util/decoder/video_container.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import copy 5 | import sys 6 | import tempfile 7 | from io import BytesIO 8 | 9 | 10 | def get_video_container(handle, multi_thread_decode=False): 11 | # Use local data. 12 | with open(handle, "rb") as fp: 13 | container = fp.read() 14 | return container 15 | -------------------------------------------------------------------------------- /mae/util/env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Set up Environment.""" 5 | 6 | from iopath.common.file_io import PathManagerFactory 7 | from iopath.fb.everstore import EverstorePathHandler 8 | from iopath.fb.manifold import ManifoldPathHandler 9 | 10 | _ENV_SETUP_DONE = False 11 | _MEMCACHE_KEY_PREFIX = "pyslowfast" 12 | _MANIFOLD_READ_CHUNK_SIZE = 200000000 # only for loading checkpoint from manifold 13 | 14 | pathmgr = PathManagerFactory.get(key="mae") 15 | checkpoint_pathmgr = PathManagerFactory.get(key="mae_checkpoint") 16 | 17 | 18 | def setup_environment(): 19 | global _ENV_SETUP_DONE 20 | if _ENV_SETUP_DONE: 21 | return 22 | _ENV_SETUP_DONE = True 23 | 24 | # Set distributed environment. 25 | import torch.fb.rendezvous.zeus # noqa 26 | 27 | # Register manifold handler for pathmgr. 28 | pathmgr.register_handler( 29 | ManifoldPathHandler( 30 | memcache_key_prefix=_MEMCACHE_KEY_PREFIX, handle_large_metadata=True 31 | ), 32 | allow_override=True, 33 | ) 34 | checkpoint_pathmgr.register_handler( 35 | ManifoldPathHandler( 36 | memcache_key_prefix=_MEMCACHE_KEY_PREFIX, 37 | handle_large_metadata=True, 38 | read_chunk_size=_MANIFOLD_READ_CHUNK_SIZE, 39 | ), 40 | allow_override=True, 41 | ) 42 | # Register everstore handler for pathmgr. 
43 | pathmgr.register_handler(EverstorePathHandler(), allow_override=True) 44 | -------------------------------------------------------------------------------- /mae/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /mae/util/logging.py: -------------------------------------------------------------------------------- 1 | """Logging.""" 2 | 3 | import atexit 4 | import builtins 5 | import decimal 6 | import functools 7 | import logging 8 | import os 9 | import sys 10 | 11 | import simplejson 12 | import torch 13 | import torch.distributed as dist 14 | from iopath.common.file_io import g_pathmgr as pathmgr 15 | 16 | 17 | def is_master_proc(multinode=False): 18 | """ 19 | Determines if the current process is the master process. 20 | """ 21 | if dist.is_initialized(): 22 | if multinode: 23 | return dist.get_rank() % dist.get_world_size() == 0 24 | else: 25 | return dist.get_rank() % torch.cuda.device_count() == 0 26 | else: 27 | return True 28 | 29 | 30 | def _suppress_print(): 31 | """ 32 | Suppresses printing from the current process. 33 | """ 34 | 35 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 36 | pass 37 | 38 | builtins.print = print_pass 39 | 40 | 41 | @functools.lru_cache(maxsize=None) 42 | def _cached_log_stream(filename): 43 | # Use 1K buffer if writing to cloud storage. 44 | io = pathmgr.open(filename, "a", buffering=1024 if "://" in filename else -1) 45 | atexit.register(io.close) 46 | return io 47 | 48 | 49 | def setup_logging(output_dir=None): 50 | """ 51 | Sets up the logging for multiple processes. Only enable the logging for the 52 | master process, and suppress logging for the non-master processes. 53 | """ 54 | # Set up logging format. 
55 | _FORMAT = "[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s" 56 | 57 | if is_master_proc(): 58 | # Enable logging for the master process. 59 | logging.root.handlers = [] 60 | else: 61 | # Suppress logging for non-master processes. 62 | _suppress_print() 63 | 64 | logger = logging.getLogger() 65 | logger.setLevel(logging.DEBUG) 66 | logger.propagate = False 67 | plain_formatter = logging.Formatter( 68 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s", 69 | datefmt="%m/%d %H:%M:%S", 70 | ) 71 | 72 | if is_master_proc(): 73 | ch = logging.StreamHandler(stream=sys.stdout) 74 | ch.setLevel(logging.DEBUG) 75 | ch.setFormatter(plain_formatter) 76 | logger.addHandler(ch) 77 | 78 | if output_dir is not None and is_master_proc(multinode=True): 79 | filename = os.path.join(output_dir, "stdout.log") 80 | fh = logging.StreamHandler(_cached_log_stream(filename)) 81 | fh.setLevel(logging.DEBUG) 82 | fh.setFormatter(plain_formatter) 83 | logger.addHandler(fh) 84 | 85 | 86 | def get_logger(name): 87 | """ 88 | Retrieve the logger with the specified name or, if name is None, return a 89 | logger which is the root logger of the hierarchy. 90 | Args: 91 | name (string): name of the logger. 92 | """ 93 | return logging.getLogger(name) 94 | 95 | 96 | def log_json_stats(stats): 97 | """ 98 | Logs json stats. 99 | Args: 100 | stats (dict): a dictionary of statistical information to log. 101 | """ 102 | stats = { 103 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v 104 | for k, v in stats.items() 105 | } 106 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 107 | logger = get_logger(__name__) 108 | print("json_stats: {:s}".format(json_stats)) 109 | 110 | 111 | def master_print(*args, **kwargs): 112 | if is_master_proc(): 113 | print(*args, **kwargs) 114 | else: 115 | pass 116 | -------------------------------------------------------------------------------- /mae/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /mae/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.21.0 3 | addict==2.4.0 4 | aenum==3.1.15 5 | aiofiles==23.1.0 6 | aiohttp==3.8.5 7 | aiosignal==1.3.1 8 | aliyun-python-sdk-core==2.14.0 9 | aliyun-python-sdk-kms==2.16.2 10 | altair==5.0.1 11 | antlr4-python3-runtime==4.9.3 12 | anyio==3.7.1 13 | asttokens==2.4.0 14 | astunparse==1.6.3 15 | async-timeout==4.0.2 16 | attrs==23.1.0 17 | av==11.0.0 18 | backcall==0.2.0 19 | basicsr==1.4.2 20 | beautifulsoup4==4.12.2 21 | blendmodes==2022 22 | boltons==23.0.0 23 | cachetools==5.3.1 24 | certifi==2023.7.22 25 | cffi==1.15.1 26 | charset-normalizer==3.2.0 27 | clean-fid==0.1.35 28 | click==8.1.6 29 | clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip 30 | cmake==3.27.0 31 | colorama==0.4.6 32 | coloredlogs==15.0.1 33 | contourpy==1.1.0 34 | crcmod==1.7 35 | cryptography==41.0.7 36 | cssselect2==0.7.0 37 | cycler==0.11.0 38 | datasets==2.13.0 39 | decorator==5.1.1 40 | decord==0.6.0 41 | deprecation==2.1.0 42 | diffusers==0.17.0 43 | dill==0.3.6 44 | efficientnet-pytorch==0.7.1 45 | einops==0.4.1 46 | exceptiongroup==1.1.2 47 | executing==1.2.0 48 | facexlib==0.3.0 49 | fastapi==0.94.0 50 | ffmpy==0.3.1 51 | filelock==3.12.2 52 | filterpy==1.4.5 53 | flatbuffers==23.5.26 54 | fonttools==4.41.1 55 | frozenlist==1.4.0 56 | fsspec==2023.9.2 57 | ftfy==6.1.1 58 | future==0.18.3 59 | fvcore==0.1.5.post20221221 60 | gast==0.5.4 61 | gdown==4.7.1 62 | gfpgan==1.3.8 63 | gitdb==4.0.10 64 | GitPython==3.1.32 65 | google-auth==2.22.0 66 | google-auth-oauthlib==1.0.0 67 | google-pasta==0.2.0 68 | gradio==3.41.2 69 | gradio_client==0.5.0 70 | grpcio==1.56.2 71 | h11==0.12.0 72 | h5py==3.10.0 73 | httpcore==0.15.0 74 | httpx==0.24.1 75 | huggingface-hub==0.16.4 76 | humanfriendly==10.0 77 | idna==3.4 78 | imageio==2.31.1 79 | importlib-metadata==6.8.0 80 | importlib-resources==6.0.0 81 | inflection==0.5.1 82 | invisible-watermark==0.2.0 83 | iopath==0.1.9 84 | ipython==8.15.0 85 | jedi==0.19.0 86 | Jinja2==3.1.2 87 | jmespath==0.10.0 88 | joblib==1.3.1 89 | jsonmerge==1.8.0 90 | jsonschema==4.18.4 91 | jsonschema-specifications==2023.7.1 92 | keras==2.15.0 93 | kiwisolver==1.4.4 94 | kornia==0.6.7 95 | lark==1.1.2 96 | lazy_loader==0.3 97 | libclang==16.0.6 98 | lightning-utilities==0.9.0 99 | linkify-it-py==2.0.2 100 | lit==16.0.6 101 | llvmlite==0.40.1 102 | lmdb==1.4.1 103 | lpips==0.1.4 104 | lxml==4.9.3 105 | Markdown==3.4.3 106 | markdown-it-py==2.2.0 107 | MarkupSafe==2.1.3 108 | matplotlib==3.7.2 109 | matplotlib-inline==0.1.6 110 | mdit-py-plugins==0.3.3 111 | mdurl==0.1.2 112 | mediapipe==0.10.2 113 | ml-dtypes==0.2.0 114 | mmengine==0.8.2 115 | model-index==0.1.11 116 | modelscope==1.9.3 117 | mpmath==1.3.0 118 | multidict==6.0.4 119 | multiprocess==0.70.14 120 | munch==4.0.0 121 | networkx==3.1 122 | ninja==1.11.1 123 | numba==0.57.1 124 | numpy==1.23.5 125 | nvidia-ml-py==12.535.133 126 | nvitop==1.3.0 127 | oauthlib==3.2.2 128 | omegaconf==2.2.3 129 | onnx==1.15.0 130 | onnxruntime==1.16.3 131 | open-clip-torch==2.20.0 132 | opencv-contrib-python==4.8.0.74 133 | 
opencv-python==4.8.0.74 134 | opendatalab==0.0.9 135 | openmim==0.3.9 136 | opt-einsum==3.3.0 137 | ordered-set==4.1.0 138 | orjson==3.9.2 139 | oss2==2.18.3 140 | packaging==23.1 141 | pandas==2.0.3 142 | parso==0.8.3 143 | pexpect==4.8.0 144 | pickleshare==0.7.5 145 | piexif==1.1.3 146 | Pillow==9.5.0 147 | platformdirs==3.9.1 148 | portalocker==2.7.0 149 | pretrainedmodels==0.7.4 150 | prompt-toolkit==3.0.39 151 | protobuf==3.20.3 152 | psutil==5.9.5 153 | ptyprocess==0.7.0 154 | pure-eval==0.2.2 155 | pyarrow==14.0.1 156 | pyasn1==0.5.0 157 | pyasn1-modules==0.3.0 158 | pycparser==2.21 159 | pycryptodome==3.18.0 160 | pydantic==1.10.11 161 | pydub==0.25.1 162 | Pygments==2.15.1 163 | pyparsing==3.0.9 164 | pyquaternion==0.9.9 165 | PySocks==1.7.1 166 | python-dateutil==2.8.2 167 | python-multipart==0.0.6 168 | pytorch-fid==0.3.0 169 | pytorch-lightning==1.9.4 170 | pytz==2023.3 171 | PyWavelets==1.4.1 172 | PyYAML==6.0.1 173 | realesrgan==0.3.0 174 | referencing==0.30.0 175 | regex==2023.6.3 176 | reportlab==4.0.4 177 | requests==2.31.0 178 | requests-oauthlib==1.3.1 179 | resize-right==0.0.2 180 | rich==13.4.2 181 | rpds-py==0.9.2 182 | rsa==4.9 183 | safetensors==0.3.1 184 | scikit-image==0.21.0 185 | scikit-learn==1.3.0 186 | scipy==1.9.1 187 | segment-anything==1.0 188 | segmentation-models-pytorch @ git+https://github.com/hyenal/segmentation_models.pytorch.git@efd6137e9f10e0701f2a1af286635fa5b363f1fd 189 | semantic-version==2.10.0 190 | sentencepiece==0.1.99 191 | shapely==2.0.2 192 | simplejson==3.19.1 193 | six==1.16.0 194 | smmap==5.0.0 195 | sniffio==1.3.0 196 | sortedcontainers==2.4.0 197 | sounddevice==0.4.6 198 | soupsieve==2.4.1 199 | stack-data==0.6.2 200 | starlette==0.26.1 201 | svglib==1.5.1 202 | sympy==1.12 203 | tabulate==0.9.0 204 | tb-nightly==2.14.0a20230723 205 | tensorboard==2.15.1 206 | tensorboard-data-server==0.7.1 207 | tensorflow-cpu==2.15.0.post1 208 | tensorflow-estimator==2.15.0 209 | tensorflow-io-gcs-filesystem==0.34.0 210 | termcolor==2.3.0 211 | threadpoolctl==3.2.0 212 | tifffile==2023.7.18 213 | timm==0.4.5 214 | tinycss2==1.2.1 215 | tokenizers==0.13.3 216 | tomesd==0.1.3 217 | tomli==2.0.1 218 | toolz==0.12.0 219 | torch==2.0.1+cu118 220 | torch-fidelity==0.3.0 221 | torchdiffeq==0.2.3 222 | torchmetrics==1.0.1 223 | torchsde==0.2.5 224 | torchvision==0.15.2+cu118 225 | tqdm==4.65.0 226 | traitlets==5.10.0 227 | trampoline==0.1.2 228 | transformers==4.30.2 229 | triton==2.0.0 230 | typing_extensions==4.7.1 231 | tzdata==2023.3 232 | uc-micro-py==1.0.2 233 | urllib3==1.26.16 234 | uvicorn==0.23.1 235 | wcwidth==0.2.6 236 | webencodings==0.5.1 237 | websockets==11.0.3 238 | Werkzeug==2.3.6 239 | wrapt==1.14.1 240 | xformers==0.0.20 241 | xxhash==3.4.1 242 | yacs==0.1.8 243 | yapf==0.40.1 244 | yarl==1.9.2 245 | zipp==3.16.2 246 | -------------------------------------------------------------------------------- /video_models/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/attention.cpython-39.pyc 
-------------------------------------------------------------------------------- /video_models/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/resnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/resnet.cpython-39.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/unet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/unet.cpython-39.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/unet_blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/unet_blocks.cpython-310.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/unet_blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/unet_blocks.cpython-39.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/video_pipeline.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/video_pipeline.cpython-39.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/video_pipeline_guided.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/video_pipeline_guided.cpython-310.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/video_pipeline_guided.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/video_pipeline_guided.cpython-39.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/video_pipeline_mae_guided.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/video_pipeline_mae_guided.cpython-310.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/video_pipeline_mae_guided.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/video_pipeline_mae_guided.cpython-39.pyc -------------------------------------------------------------------------------- /video_models/__pycache__/video_pipeline_mae_guided_2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/video_models/__pycache__/video_pipeline_mae_guided_2.cpython-310.pyc -------------------------------------------------------------------------------- /video_models/bug.txt: -------------------------------------------------------------------------------- 1 | During 3D training with gradient checkpointing enabled, some of the selected trainable parameters are in fact never trained (they receive no updates). -------------------------------------------------------------------------------- /video_models/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | import decord 6 | decord.bridge.set_bridge('torch') 7 | import torch 8 | import torchvision 9 | import PIL 10 | from typing import List 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | 14 | from controlnet_aux import CannyDetector 15 | 16 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): 17 | videos = rearrange(videos, "b c t h w -> t b c h w") 18 | outputs = [] 19 | for x in videos: 20 | x = torchvision.utils.make_grid(x, nrow=n_rows) 21 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 22 | if rescale: 23 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 24 | x = (x * 255).numpy().astype(np.uint8) 25 | outputs.append(x) 26 | 27 | os.makedirs(os.path.dirname(path), exist_ok=True) 28 | imageio.mimsave(path, outputs, fps=fps) 29 | 30 | def save_videos_grid_pil(videos: List[PIL.Image.Image], path: str, rescale=False, n_rows=4, fps=8): 31 | videos = rearrange(videos, "b c t h w -> t b c h w") 32 | outputs = [] 33 | for x in videos: 34 | x = torchvision.utils.make_grid(x, nrow=n_rows) 35 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 36 | if rescale: 37 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 38 | x = (x * 255).numpy().astype(np.uint8) 39 | outputs.append(x) 40 | 41 | os.makedirs(os.path.dirname(path), exist_ok=True) 42 | imageio.mimsave(path, outputs, fps=fps) 43 | 44 | def read_video(video_path, video_length, width=512, height=512, frame_rate=None): 45 | vr = decord.VideoReader(video_path, width=width, height=height) 46 | if frame_rate is None: 47 | frame_rate = max(1, len(vr) // video_length) 48 | sample_index = list(range(0, len(vr), frame_rate))[:video_length] 49 | video = vr.get_batch(sample_index) 50 | video = rearrange(video, "f h w c -> f c h w") 51 | video = (video / 127.5 - 1.0) 52 | return video 53 | 54 | 55 | def get_annotation(video, annotator): 56 | t2i_transform = torchvision.transforms.ToPILImage() 57 | annotation = [] 58 | for frame in video: 59 | pil_frame = t2i_transform(frame) 60 | if isinstance(annotator, CannyDetector): 61 | annotation.append(annotator(pil_frame,
low_threshold=100, high_threshold=200)) 62 | else: 63 | annotation.append(annotator(pil_frame)) 64 | return annotation 65 | 66 | # DDIM Inversion 67 | @torch.no_grad() 68 | def init_prompt(prompt, pipeline): 69 | uncond_input = pipeline.tokenizer( 70 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 71 | return_tensors="pt" 72 | ) 73 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 74 | text_input = pipeline.tokenizer( 75 | [prompt], 76 | padding="max_length", 77 | max_length=pipeline.tokenizer.model_max_length, 78 | truncation=True, 79 | return_tensors="pt", 80 | ) 81 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 82 | context = torch.cat([uncond_embeddings, text_embeddings]) 83 | 84 | return context 85 | 86 | 87 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 88 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 89 | timestep, next_timestep = min( 90 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 91 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 92 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 93 | beta_prod_t = 1 - alpha_prod_t 94 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 95 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 96 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 97 | return next_sample 98 | 99 | 100 | def get_noise_pred_single(latents, t, context, unet): 101 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 102 | return noise_pred 103 | 104 | 105 | @torch.no_grad() 106 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 107 | context = init_prompt(prompt, pipeline) 108 | uncond_embeddings, cond_embeddings = context.chunk(2) 109 | all_latent = [latent] 110 | latent = latent.clone().detach() 111 | for i in tqdm(range(num_inv_steps)): 112 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 113 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 114 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 115 | all_latent.append(latent) 116 | return all_latent 117 | 118 | 119 | @torch.no_grad() 120 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 121 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 122 | return ddim_latents 123 | -------------------------------------------------------------------------------- /videomae/DATASET.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | 3 | We have successfully pre-trained and fine-tuned our VideoMAE on [Kinetics400](https://deepmind.com/research/open-source/kinetics), [Something-Something-V2](https://developer.qualcomm.com/software/ai-datasets/something-something), [UCF101](https://www.crcv.ucf.edu/data/UCF101.php) and [HMDB51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) with this codebase. 4 | 5 | - The pre-processing of **Something-Something-V2** can be summarized into 3 steps: 6 | 7 | 1. Download the dataset from [official website](https://developer.qualcomm.com/software/ai-datasets/something-something). 8 | 9 | 2. 
Preprocess the dataset by changing the video extension from `.webm` to `.mp4`, keeping the **original** height of **240px**. 10 | 11 | 3. Generate the annotations needed by the dataloader (each line of an annotation file is a `video_path label` pair, as shown below). The annotations usually include `train.csv`, `val.csv` and `test.csv` (here `test.csv` is the same as `val.csv`). We **share** our annotation files (train.csv, val.csv, test.csv) via [Google Drive](https://drive.google.com/drive/folders/1cfA-SrPhDB9B8ZckPvnh8D5ysCjD-S_I?usp=share_link). The format of the `*.csv` files is as follows: 12 | 13 | ``` 14 | dataset_root/video_1.mp4 label_1 15 | dataset_root/video_2.mp4 label_2 16 | dataset_root/video_3.mp4 label_3 17 | ... 18 | dataset_root/video_N.mp4 label_N 19 | ``` 20 | 21 | - The pre-processing of **Kinetics400** can be summarized into 3 steps: 22 | 23 | 1. Download the dataset from the [official website](https://deepmind.com/research/open-source/kinetics). 24 | 25 | 2. Preprocess the dataset by resizing the short edge of each video to **320px**. You can refer to the [MMAction2 Data Benchmark](https://github.com/open-mmlab/mmaction2) for [TSN](https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/tsn#kinetics-400-data-benchmark-8-gpus-resnet50-imagenet-pretrain-3-segments) and [SlowOnly](https://github.com/open-mmlab/mmaction2/tree/master/configs/recognition/slowonly#kinetics-400-data-benchmark).
26 | **Recommended**: [OpenDataLab](https://opendatalab.com/) provides a copy of the [Kinetics400](https://opendatalab.com/Kinetics-400) dataset; you can download the Kinetics dataset with **short edge 320px** from [here](https://opendatalab.com/Kinetics-400).
27 | 28 | 3. Generate the annotations needed by the dataloader (each line of an annotation file is a `video_path label` pair, as shown below). The annotations usually include `train.csv`, `val.csv` and `test.csv` (here `test.csv` is the same as `val.csv`). The format of the `*.csv` files is as follows: 29 | 30 | ``` 31 | dataset_root/video_1.mp4 label_1 32 | dataset_root/video_2.mp4 label_2 33 | dataset_root/video_3.mp4 label_3 34 | ... 35 | dataset_root/video_N.mp4 label_N 36 | ``` 37 | 38 | ### Note: 39 | 40 | 1. We use [decord](https://github.com/dmlc/decord) to decode the videos **on the fly** during both the pre-training and fine-tuning phases. 41 | 2. All experiments on Kinetics-400 in VideoMAE are based on [this version](https://opendatalab.com/Kinetics-400). 42 | -------------------------------------------------------------------------------- /videomae/INSTALL.md: -------------------------------------------------------------------------------- 1 | # VideoMAE Installation 2 | 3 | The codebase is mainly built with the following libraries: 4 | 5 | - Python 3.6 or higher 6 | 7 | - [PyTorch](https://pytorch.org/) and [torchvision](https://github.com/pytorch/vision).
8 | We can successfully reproduce the main results under the two settings below:
9 | Tesla **A100** (40G): CUDA 11.1 + PyTorch 1.8.0 + torchvision 0.9.0
10 | Tesla **V100** (32G): CUDA 10.1 + PyTorch 1.6.0 + torchvision 0.7.0 11 | 12 | - [timm==0.4.8/0.4.12](https://github.com/rwightman/pytorch-image-models) 13 | 14 | - [deepspeed==0.5.8](https://github.com/microsoft/DeepSpeed) 15 | 16 | `DS_BUILD_OPS=1 pip install deepspeed` 17 | 18 | - [TensorboardX](https://github.com/lanpa/tensorboardX) 19 | 20 | - [decord](https://github.com/dmlc/decord) 21 | 22 | - [einops](https://github.com/arogozhnikov/einops) 23 | 24 | ### Note: 25 | - We recommend you to use **`PyTorch >= 1.8.0`**. 26 | - We observed accidental interrupt in the last epoch when conducted the pre-training experiments on V100 GPUs (PyTorch 1.6.0). This interrupt is caused by the scheduler of learning rate. We naively set `--epochs 801` to walk away from issue :) 27 | 28 | -------------------------------------------------------------------------------- /videomae/PRETRAIN.md: -------------------------------------------------------------------------------- 1 | # Pre-training VideoMAE 2 | 3 | ## Original Implementation 4 | 5 | The implementation of our VideoMAE supports **multi-node distributed training**. We provide the **off-the-shelf** scripts in the [scripts folder](scripts). 6 | 7 | - For example, to pre-train VideoMAE ViT-Base on **Something-Something V2** with 64 GPUs (8 nodes x 8 GPUs), you can run 8 | 9 | ```bash 10 | OUTPUT_DIR='YOUR_PATH/ssv2_videomae_pretrain_base_patch16_224_frame_16x2_tube_mask_ratio_0.9_e800' 11 | DATA_PATH='YOUR_PATH/list_ssv2/train.csv' 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 12320 --nnodes=8 \ 15 | --node_rank=0 --master_addr=$ip_node_0 \ 16 | run_mae_pretraining.py \ 17 | --data_path ${DATA_PATH} \ 18 | --mask_type tube \ 19 | --mask_ratio 0.9 \ 20 | --model pretrain_videomae_base_patch16_224 \ 21 | --decoder_depth 4 \ 22 | --batch_size 32 \ 23 | --num_frames 16 \ 24 | --sampling_rate 2 \ 25 | --opt adamw \ 26 | --opt_betas 0.9 0.95 \ 27 | --warmup_epochs 40 \ 28 | --save_ckpt_freq 20 \ 29 | --epochs 801 \ 30 | --log_dir ${OUTPUT_DIR} \ 31 | --output_dir ${OUTPUT_DIR} 32 | ``` 33 | 34 | on the first node. On other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7` respectively. `--master_addr` is set as the ip of the node 0. 35 | 36 | - For example, to pre-train VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run 37 | 38 | ```bash 39 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800' 40 | DATA_PATH='YOUR_PATH/list_kinetics-400/train.csv' 41 | 42 | OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=8 \ 43 | --master_port 12320 --nnodes=8 \ 44 | --node_rank=0 --master_addr=$your_node_0_ip \ 45 | run_mae_pretraining.py \ 46 | --data_path ${DATA_PATH} \ 47 | --mask_type tube \ 48 | --mask_ratio 0.9 \ 49 | --model pretrain_videomae_base_patch16_224 \ 50 | --decoder_depth 4 \ 51 | --batch_size 32 \ 52 | --num_frames 16 \ 53 | --sampling_rate 4 \ 54 | --opt adamw \ 55 | --opt_betas 0.9 0.95 \ 56 | --warmup_epochs 40 \ 57 | --save_ckpt_freq 20 \ 58 | --epochs 801 \ 59 | --log_dir ${OUTPUT_DIR} \ 60 | --output_dir ${OUTPUT_DIR} 61 | ``` 62 | 63 | on the first node. On other nodes, run the same command with `--node_rank 1`, ..., `--node_rank 7` respectively. `--master_addr` is set as the ip of the node 0. 64 | 65 | ### Note: 66 | 67 | - Here the batch size is 32 (`batch_size` per gpu) * 8 (`nodes`) * 8 (gpus per node) = 2048. 
68 | - `lr` here is the base learning rate and is set to `1.5e-4` by default. The `actual lr` is computed by the [linear scaling rule](https://arxiv.org/abs/1706.02677): `actual lr` = `lr` * total batch size / 256 (see the short worked example below). 69 | - [Fixed]~~We observed an accidental interrupt in the last epoch when conducting the experiment on V100 GPUs (torch 1.6.0). This interrupt was caused by the learning-rate scheduler. We naively set `--epochs 801` to work around the issue :)~~ 70 | 71 | ## Slurm 72 | 73 | To help the community reproduce our results on slurm clusters, we also provide the **off-the-shelf** script. 74 | 75 | For example, to pre-train VideoMAE ViT-Base on **Kinetics400** with 64 GPUs (8 nodes x 8 GPUs), you can run 76 | 77 | ```bash 78 | export MASTER_PORT=$((12000 + $RANDOM % 20000)) 79 | export OMP_NUM_THREADS=1 80 | 81 | OUTPUT_DIR='YOUR_PATH/k400_videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e800' 82 | DATA_PATH='YOUR_PATH/list_kinetics-400/train.csv' 83 | 84 | JOB_NAME=$1 85 | PARTITION=${PARTITION:-"video"} 86 | # 8 for 1 node, 16 for 2 nodes, etc. 87 | GPUS=${GPUS:-64} 88 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 89 | CPUS_PER_TASK=${CPUS_PER_TASK:-8} 90 | SRUN_ARGS=${SRUN_ARGS:-""} 91 | PY_ARGS=${@:2} 92 | 93 | # batch_size can be adjusted according to the graphics card 94 | srun -p $PARTITION \ 95 | --job-name=${JOB_NAME} \ 96 | --gres=gpu:${GPUS_PER_NODE} \ 97 | --ntasks=${GPUS} \ 98 | --ntasks-per-node=${GPUS_PER_NODE} \ 99 | --cpus-per-task=${CPUS_PER_TASK} \ 100 | --kill-on-bad-exit=1 \ 101 | ${SRUN_ARGS} \ 102 | python -u run_mae_pretraining.py \ 103 | --data_path ${DATA_PATH} \ 104 | --mask_type tube \ 105 | --mask_ratio 0.9 \ 106 | --model pretrain_videomae_base_patch16_224 \ 107 | --decoder_depth 4 \ 108 | --batch_size 32 \ 109 | --num_frames 16 \ 110 | --sampling_rate 4 \ 111 | --opt adamw \ 112 | --opt_betas 0.9 0.95 \ 113 | --warmup_epochs 40 \ 114 | --save_ckpt_freq 20 \ 115 | --epochs 801 \ 116 | --log_dir ${OUTPUT_DIR} \ 117 | --output_dir ${OUTPUT_DIR} \ 118 | ${PY_ARGS} 119 | ``` 120 | 121 | -------------------------------------------------------------------------------- /videomae/__pycache__/datasets.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/datasets.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/functional.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/functional.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/kinetics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/kinetics.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/masking_generator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/masking_generator.cpython-310.pyc
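As a quick sanity check of the linear scaling rule quoted in `videomae/PRETRAIN.md` above, using only the numbers given there (base `lr` 1.5e-4, `batch_size` 32 per GPU, 8 nodes x 8 GPUs), the `actual lr` works out as follows. This is an illustrative sketch, not code from the repository:

```python
# Linear scaling rule from videomae/PRETRAIN.md: actual lr = lr * total batch size / 256
base_lr = 1.5e-4
total_batch_size = 32 * 8 * 8      # per-GPU batch size * GPUs per node * nodes = 2048

actual_lr = base_lr * total_batch_size / 256
print(f"{actual_lr:.2e}")          # -> 1.20e-03
```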
-------------------------------------------------------------------------------- /videomae/__pycache__/modeling_finetune.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/modeling_finetune.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/modeling_pretrain.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/modeling_pretrain.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/rand_augment.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/rand_augment.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/random_erasing.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/random_erasing.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/ssv2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/ssv2.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/video_transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/video_transforms.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/__pycache__/volume_transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scnuhealthy/video_try_on/d78b86fba73c04d4488ee99c8b111bc3614b3d4d/videomae/__pycache__/volume_transforms.cpython-310.pyc -------------------------------------------------------------------------------- /videomae/engine_for_pretraining.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | import torch 5 | import torch.nn as nn 6 | import utils 7 | from einops import rearrange 8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | 10 | def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, 11 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 
patch_size: int = 16, 12 | normlize_target: bool = True, log_writer=None, lr_scheduler=None, start_steps=None, 13 | lr_schedule_values=None, wd_schedule_values=None): 14 | model.train() 15 | metric_logger = utils.MetricLogger(delimiter=" ") 16 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 17 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 18 | header = 'Epoch: [{}]'.format(epoch) 19 | print_freq = 10 20 | 21 | loss_func = nn.MSELoss() 22 | 23 | for step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 24 | # assign learning rate & weight decay for each step 25 | it = start_steps + step # global training iteration 26 | if lr_schedule_values is not None or wd_schedule_values is not None: 27 | for i, param_group in enumerate(optimizer.param_groups): 28 | if lr_schedule_values is not None: 29 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 30 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 31 | param_group["weight_decay"] = wd_schedule_values[it] 32 | 33 | videos, bool_masked_pos = batch 34 | videos = videos.to(device, non_blocking=True) 35 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True).flatten(1).to(torch.bool) 36 | 37 | with torch.no_grad(): 38 | # calculate the predict label 39 | mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None] 40 | std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None] 41 | unnorm_videos = videos * std + mean # in [0, 1] 42 | 43 | if normlize_target: 44 | videos_squeeze = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c', p0=2, p1=patch_size, p2=patch_size) 45 | videos_norm = (videos_squeeze - videos_squeeze.mean(dim=-2, keepdim=True) 46 | ) / (videos_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6) 47 | # we find that the mean is about 0.48 and standard deviation is about 0.08. 48 | videos_patch = rearrange(videos_norm, 'b n p c -> b n (p c)') 49 | else: 50 | videos_patch = rearrange(unnorm_videos, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=patch_size, p2=patch_size) 51 | 52 | B, _, C = videos_patch.shape 53 | labels = videos_patch[bool_masked_pos].reshape(B, -1, C) 54 | 55 | with torch.cuda.amp.autocast(): 56 | outputs = model(videos, bool_masked_pos) 57 | loss = loss_func(input=outputs, target=labels) 58 | 59 | loss_value = loss.item() 60 | 61 | if not math.isfinite(loss_value): 62 | print("Loss is {}, stopping training".format(loss_value)) 63 | sys.exit(1) 64 | 65 | optimizer.zero_grad() 66 | # this attribute is added by timm on one optimizer (adahessian) 67 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 68 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 69 | parameters=model.parameters(), create_graph=is_second_order) 70 | loss_scale_value = loss_scaler.state_dict()["scale"] 71 | 72 | torch.cuda.synchronize() 73 | 74 | metric_logger.update(loss=loss_value) 75 | metric_logger.update(loss_scale=loss_scale_value) 76 | min_lr = 10. 77 | max_lr = 0. 
78 | for group in optimizer.param_groups: 79 | min_lr = min(min_lr, group["lr"]) 80 | max_lr = max(max_lr, group["lr"]) 81 | 82 | metric_logger.update(lr=max_lr) 83 | metric_logger.update(min_lr=min_lr) 84 | weight_decay_value = None 85 | for group in optimizer.param_groups: 86 | if group["weight_decay"] > 0: 87 | weight_decay_value = group["weight_decay"] 88 | metric_logger.update(weight_decay=weight_decay_value) 89 | metric_logger.update(grad_norm=grad_norm) 90 | 91 | if log_writer is not None: 92 | log_writer.update(loss=loss_value, head="loss") 93 | log_writer.update(loss_scale=loss_scale_value, head="opt") 94 | log_writer.update(lr=max_lr, head="opt") 95 | log_writer.update(min_lr=min_lr, head="opt") 96 | log_writer.update(weight_decay=weight_decay_value, head="opt") 97 | log_writer.update(grad_norm=grad_norm, head="opt") 98 | log_writer.set_step() 99 | 100 | if lr_scheduler is not None: 101 | lr_scheduler.step_update(start_steps + step) 102 | # gather the stats from all processes 103 | metric_logger.synchronize_between_processes() 104 | print("Averaged stats:", metric_logger) 105 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 106 | -------------------------------------------------------------------------------- /videomae/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.BILINEAR 58 | else: 59 | pil_inter = PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | 
ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip 90 | -------------------------------------------------------------------------------- /videomae/masking_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class TubeMaskingGenerator: 4 | def __init__(self, input_size, mask_ratio): 5 | self.frames, self.height, self.width = input_size 6 | self.num_patches_per_frame = self.height * self.width 7 | self.total_patches = self.frames * self.num_patches_per_frame 8 | self.num_masks_per_frame = int(mask_ratio * self.num_patches_per_frame) 9 | self.total_masks = self.frames * self.num_masks_per_frame 10 | 11 | def __repr__(self): 12 | repr_str = "Mask: total patches {}, mask patches {}".format( 13 | self.total_patches, self.total_masks 14 | ) 15 | return repr_str 16 | 17 | def __call__(self): 18 | mask_per_frame = np.hstack([ 19 | np.zeros(self.num_patches_per_frame - self.num_masks_per_frame), 20 | np.ones(self.num_masks_per_frame), 21 | ]) 22 | np.random.shuffle(mask_per_frame) 23 | mask = np.tile(mask_per_frame, (self.frames,1)).flatten() # repeat the same spatial mask across all frames (tube masking) 24 | return mask -------------------------------------------------------------------------------- /videomae/vis.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save video 2 | OUTPUT_DIR='TODO/VideoMAE/demo/vis_k400_1_0.9' 3 | # path to video for visualization 4 | VIDEO_PATH='TODO/TODO.avi' 5 | # path to pretrain model 6 | # MODEL_PATH='TODO/videomae_pretrain_base_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint-1599.pth' 7 | MODEL_PATH='TODO/vit_base.pth' 8 | 9 | python3 run_videomae_vis.py \ 10 | --mask_ratio 0.9 \ 11 | --mask_type tube \ 12 | --decoder_depth 4 \ 13 | --model pretrain_videomae_base_patch16_224 \ 14 | ${VIDEO_PATH} ${OUTPUT_DIR} ${MODEL_PATH} -------------------------------------------------------------------------------- /videomae/volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, H, W) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor.
30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | tensor = torch.from_numpy(array) 131 | return tensor 132 | -------------------------------------------------------------------------------- /vis2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import requests 4 | 5 | import torch 6 | import numpy as np 7 | 8 | import matplotlib.pyplot as plt 9 | from PIL import Image 10 | 11 | sys.path.append('./mae') 12 | 13 | from mae.util.decoder.utils import tensor_normalize, 
spatial_sampling 14 | import random 15 | # seed 16 | seed = 4 17 | random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | 21 | MEAN = (0.45, 0.45, 0.45) 22 | STD = (0.225, 0.225, 0.225) 23 | 24 | 25 | def plot_input(tensor): 26 | tensor = tensor.float() 27 | f, ax = plt.subplots(nrows=tensor.shape[0], ncols=tensor.shape[1], figsize=(50, 20)) 28 | 29 | tensor = tensor * torch.tensor(STD).view(3, 1, 1) 30 | tensor = tensor + torch.tensor(MEAN).view(3, 1, 1) 31 | tensor = torch.clip(tensor * 255, 0, 255).int() 32 | 33 | for i in range(tensor.shape[0]): 34 | for j in range(tensor.shape[1]): 35 | ax[i][j].axis("off") 36 | ax[i][j].imshow(tensor[i][j].permute(1, 2, 0)) 37 | plt.show() 38 | 39 | 40 | from mae.models_mae import mae_vit_large_patch16 41 | model = mae_vit_large_patch16(decoder_embed_dim=512, decoder_depth=4, mask_type="st", t_patch_size=2, img_size=224) 42 | 43 | checkpoint = torch.load("./video-mae-100x4-joint.pth", map_location='cpu') 44 | msg = model.load_state_dict(checkpoint['model'], strict=False) 45 | model = model.cuda().to(dtype=torch.float16) 46 | 47 | # load data into tensor 48 | file_path = 'mae_test/07148_00_smooth' 49 | frames = [] 50 | frame_names = sorted(os.listdir(file_path)) 51 | for franme_name in frame_names: 52 | if franme_name[-3:] == 'gif': 53 | continue 54 | frame_path = os.path.join(file_path, franme_name) 55 | frame = Image.open(frame_path) 56 | frame = torch.tensor(np.array(frame)) 57 | frames.append(frame) 58 | frames = torch.stack(frames) 59 | print(frames.shape) 60 | frames = tensor_normalize( 61 | frames, 62 | torch.tensor(MEAN), 63 | torch.tensor(STD), 64 | ).permute(3, 0, 1, 2) 65 | print(frames.shape) 66 | frames = spatial_sampling( 67 | frames, 68 | spatial_idx=1, 69 | min_scale=256, 70 | max_scale=256, 71 | crop_size=224, 72 | random_horizontal_flip=False, 73 | inverse_uniform_sampling=False, 74 | aspect_ratio=None, 75 | scale=None, 76 | motion_shift=False, 77 | ) 78 | print(frames.shape) 79 | frames = frames.cuda().to(dtype=torch.float16) 80 | for ratio in [0.3, 0.5, 0.7, 0.9]: 81 | loss, _, _, vis = model(frames.unsqueeze(0), 1, mask_ratio=ratio, visualize=True) 82 | vis = vis.detach().cpu() 83 | print(ratio, loss) 84 | plot_input(vis[0].permute(0, 2, 1, 3, 4)) -------------------------------------------------------------------------------- /wild_config.py: -------------------------------------------------------------------------------- 1 | # data 2 | # dataroot = '/root/autodl-tmp/my_clothes_dataset/' 3 | dataroot = '/root/autodl-tmp/wild_video_dataset/10251_dataset' 4 | # dataroot = '/root/autodl-tmp/10091_dataset/' 5 | fine_height = 512 6 | fine_width = 384 7 | semantic_nc = 13 8 | with_one_hot= False 9 | is_atr = True 10 | 11 | # infer 12 | model_path = '/root/autodl-tmp/df-1.5' 13 | # model_path = 'runwayml/stable-diffusion-v1-5' 14 | unet_path = '/root/autodl-tmp/agnostic_norm_hair_have_background/checkpoint-50000/' 15 | vae_path = '/root/autodl-tmp/HR_VITON_vae' 16 | out_dir = 'test_guide' 17 | 18 | test_dataset = 'Wild' 19 | infer_datamode = 'test' 20 | infer_data_list = 'test_pairs.txt' 21 | infer_datasetting = 'unpaired' 22 | 23 | # train data 24 | train_datamode = 'train' 25 | train_data_list = 'train_pairs.txt' 26 | train_datasetting = 'paired' 27 | 28 | # consecutive prediction 29 | videos_root = '/root/autodl-tmp/wild_video_dataset/' 30 | videos_list = '/root/autodl-tmp/wild_video_dataset/data_list.txt' 31 | output_root = '/root/autodl-tmp/video_try_on/output' 32 | 
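The module-style configs in this repository (`config.py`, `wild_config.py` above) are read by importing the module and accessing its attributes directly. Below is a minimal sketch of that pattern; only the attribute names come from `wild_config.py`, while the `build_transform` helper and the transform pipeline are hypothetical, not code from the repository:

```python
# Minimal usage sketch for a module-style config such as wild_config.py.
import torchvision.transforms as T

import wild_config as config


def build_transform():
    # fine_height / fine_width define the working resolution (512 x 384 in wild_config.py),
    # presumably the size frames are resized to before further processing.
    return T.Compose([
        T.Resize((config.fine_height, config.fine_width)),
        T.ToTensor(),
    ])


if __name__ == "__main__":
    print("data root:", config.dataroot)
    print("inference list:", config.infer_data_list, "| setting:", config.infer_datasetting)
    transform = build_transform()
```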
--------------------------------------------------------------------------------