├── .gitignore ├── requirements.txt ├── src ├── masks │ ├── default.py │ ├── utils.py │ ├── random_tube.py │ └── multiblock3d.py ├── models │ ├── utils │ │ ├── multimask.py │ │ ├── patch_embed.py │ │ ├── pos_embs.py │ │ └── modules.py │ ├── attentive_pooler.py │ ├── predictor.py │ └── vision_transformer.py ├── datasets │ ├── image_dataset.py │ ├── data_manager.py │ ├── utils │ │ ├── weighted_sampler.py │ │ └── video │ │ │ ├── functional.py │ │ │ ├── volume_transforms.py │ │ │ ├── randerase.py │ │ │ └── randaugment.py │ └── video_dataset.py └── utils │ ├── schedulers.py │ ├── tensors.py │ ├── distributed.py │ ├── logging.py │ └── monitoring.py ├── app ├── scaffold.py ├── main.py ├── main_distributed.py └── vjepa │ ├── transforms.py │ └── utils.py ├── setup.py ├── evals ├── scaffold.py ├── main.py ├── main_distributed.py ├── video_classification_frozen │ └── utils.py └── image_classification_frozen │ └── eval.py ├── configs ├── evals │ ├── vith16_inat.yaml │ ├── vitl16_inat.yaml │ ├── vith16_384_inat.yaml │ ├── vith16_in1k.yaml │ ├── vith16_places.yaml │ ├── vitl16_in1k.yaml │ ├── vith16_384_in1k.yaml │ ├── vith16_384_places.yaml │ ├── vitl16_places.yaml │ ├── vith16_ssv2_16x2x3.yaml │ ├── vitl16_ssv2_16x2x3.yaml │ ├── vith16_384_ssv2_16x2x3.yaml │ ├── vith16_k400_16x8x3.yaml │ ├── vitl16_k400_16x8x3.yaml │ └── vith16_384_k400_16x8x3.yaml └── pretrain │ ├── vith16.yaml │ ├── vitl16.yaml │ └── vith16_384.yaml ├── CONTRIBUTING.md └── CODE_OF_CONDUCT.md /.gitignore: -------------------------------------------------------------------------------- 1 | .*.swp 2 | *.pyc 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2 2 | torchvision 3 | pyyaml 4 | numpy 5 | opencv-python 6 | submitit 7 | braceexpand 8 | webdataset 9 | timm 10 | decord 11 | pandas 12 | einops 13 | beartype 14 | -------------------------------------------------------------------------------- /src/masks/default.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | 10 | import torch 11 | 12 | _GLOBAL_SEED = 0 13 | logger = getLogger() 14 | 15 | 16 | class DefaultCollator(object): 17 | 18 | def __call__(self, batch): 19 | collated_batch = torch.utils.data.default_collate(batch) 20 | return collated_batch, None, None 21 | -------------------------------------------------------------------------------- /app/scaffold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
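A minimal usage sketch for the DefaultCollator defined in src/masks/default.py above, assuming the repository root is importable (e.g. after the editable install described in setup.py later in this listing). The toy tensors, shapes, and batch size are illustrative; the two trailing None values occupy the slots that the masking collators in this directory (e.g. multiblock3d.py) fill with mask indices.

import torch
from torch.utils.data import DataLoader, TensorDataset

from src.masks.default import DefaultCollator

# Toy dataset: 8 fake clips of shape [C=3, T=16, H=32, W=32] with integer labels.
videos = torch.randn(8, 3, 16, 32, 32)
labels = torch.randint(0, 10, (8,))
loader = DataLoader(
    TensorDataset(videos, labels),
    batch_size=4,
    collate_fn=DefaultCollator())

for (clips, targets), masks_enc, masks_pred in loader:
    # DefaultCollator wraps torch.utils.data.default_collate and returns two
    # None placeholders where the masking collators would return mask indices.
    assert masks_enc is None and masks_pred is None
    print(clips.shape, targets.shape)  # [4, 3, 16, 32, 32] and [4]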
6 | # 7 | 8 | import importlib 9 | import logging 10 | import sys 11 | 12 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 13 | logger = logging.getLogger() 14 | 15 | 16 | def main(app, args, resume_preempt=False): 17 | 18 | logger.info(f'Running pre-training of app: {app}') 19 | return importlib.import_module(f'app.{app}.train').main( 20 | args=args, 21 | resume_preempt=resume_preempt) 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | 7 | import os 8 | from setuptools import setup 9 | 10 | VERSION = "0.0.1" 11 | 12 | def get_requirements(): 13 | with open("./requirements.txt") as reqsf: 14 | reqs = reqsf.readlines() 15 | return reqs 16 | 17 | 18 | if __name__ == "__main__": 19 | setup( 20 | name="jepa", 21 | version=VERSION, 22 | description="JEPA research code.", 23 | python_requires=">=3.9", 24 | install_requires=get_requirements(), 25 | ) 26 | -------------------------------------------------------------------------------- /evals/scaffold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import importlib 9 | import logging 10 | import sys 11 | 12 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 13 | logger = logging.getLogger() 14 | 15 | 16 | def main( 17 | eval_name, 18 | args_eval, 19 | resume_preempt=False 20 | ): 21 | logger.info(f'Running evaluation: {eval_name}') 22 | return importlib.import_module(f'evals.{eval_name}.eval').main( 23 | args_eval=args_eval, 24 | resume_preempt=resume_preempt) 25 | -------------------------------------------------------------------------------- /src/masks/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
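app/scaffold.py and evals/scaffold.py above both dispatch by importing a module path built from the config. A minimal sketch of that contract follows; the 'vjepa' example assumes an app/vjepa/train.py module, which the pretraining configs later in this listing select through their app: key but which is not itself part of this listing.

import importlib


def dispatch_app(app_name, args, resume_preempt=False):
    # Same pattern as app/scaffold.py: resolve `app.<name>.train` at run time
    # and call its main() entry point with the parsed YAML config.
    module = importlib.import_module(f'app.{app_name}.train')
    return module.main(args=args, resume_preempt=resume_preempt)


# A new training app therefore only needs to provide app/<name>/train.py with a
# function main(args, resume_preempt=False); the pretraining configs pick it
# via their `app: vjepa` entry, e.g. (params_loaded_from_yaml is hypothetical):
#   dispatch_app('vjepa', args=params_loaded_from_yaml)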
6 | # 7 | 8 | import torch 9 | 10 | 11 | def apply_masks(x, masks, concat=True): 12 | """ 13 | :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] 14 | :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep 15 | """ 16 | all_x = [] 17 | for m in masks: 18 | mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) 19 | all_x += [torch.gather(x, dim=1, index=mask_keep)] 20 | if not concat: 21 | return all_x 22 | 23 | return torch.cat(all_x, dim=0) 24 | -------------------------------------------------------------------------------- /configs/evals/vith16_inat.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: inat-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: iNaturalist-2021/110421/ 9 | num_classes: 10000 10 | resolution: 224 11 | dataset_name: iNat21 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_huge 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vitl16_inat.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: inat-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: iNaturalist-2021/110421/ 9 | num_classes: 10000 10 | resolution: 224 11 | dataset_name: iNat21 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_large 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vith16_384_inat.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: inat-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: iNaturalist-2021/110421/ 9 | num_classes: 10000 10 | resolution: 384 11 | dataset_name: iNat21 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 
20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_huge 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vith16_in1k.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: in1k-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: imagenet_full_size/061417/ 9 | num_classes: 1000 10 | resolution: 224 11 | dataset_name: ImageNet 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_huge 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vith16_places.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: places-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: places205/121517/pytorch/ 9 | num_classes: 205 10 | resolution: 224 11 | dataset_name: Places205 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_huge 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vitl16_in1k.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: in1k-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: imagenet_full_size/061417/ 9 | num_classes: 1000 10 | resolution: 224 11 | dataset_name: ImageNet 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 
20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_large 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vith16_384_in1k.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: in1k-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: imagenet_full_size/061417/ 9 | num_classes: 1000 10 | resolution: 384 11 | dataset_name: ImageNet 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_huge 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vith16_384_places.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: places-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: places205/121517/pytorch/ 9 | num_classes: 205 10 | resolution: 384 11 | dataset_name: Places205 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_huge 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vitl16_places.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: places-16f 4 | eval_name: image_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ 8 | image_folder: places205/121517/pytorch/ 9 | num_classes: 205 10 | resolution: 224 11 | dataset_name: Places205 12 | optimization: 13 | num_epochs: 20 14 | batch_size: 16 15 | weight_decay: 0.001 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.0 19 | warmup: 0. 
20 | use_bfloat16: true 21 | pretrain: 22 | model_name: vit_large 23 | checkpoint_key: target_encoder 24 | clip_duration: null 25 | frames_per_clip: 16 26 | tubelet_size: 2 27 | uniform_power: true 28 | use_sdpa: true 29 | use_silu: false 30 | tight_silu: false 31 | patch_size: 16 32 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 33 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 34 | write_tag: jepa 35 | -------------------------------------------------------------------------------- /configs/evals/vith16_ssv2_16x2x3.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: ssv2-16x2x3 4 | eval_name: video_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | dataset_train: /your_path_to_ssv2_train_csv_file_index.csv 8 | dataset_val: /your_path_to_ssv2_val_csv_file_index.csv 9 | dataset_type: VideoDataset 10 | num_classes: 174 11 | frames_per_clip: 16 12 | num_segments: 2 13 | num_views_per_segment: 3 14 | frame_step: 4 15 | optimization: 16 | attend_across_segments: true 17 | num_epochs: 20 18 | resolution: 224 19 | batch_size: 4 20 | weight_decay: 0.01 21 | lr: 0.001 22 | start_lr: 0.001 23 | final_lr: 0.0 24 | warmup: 0. 25 | use_bfloat16: true 26 | pretrain: 27 | model_name: vit_huge 28 | checkpoint_key: target_encoder 29 | clip_duration: null 30 | frames_per_clip: 16 31 | tubelet_size: 2 32 | uniform_power: true 33 | use_silu: false 34 | tight_silu: false 35 | use_sdpa: true 36 | patch_size: 16 37 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 38 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 39 | write_tag: jepa 40 | -------------------------------------------------------------------------------- /configs/evals/vitl16_ssv2_16x2x3.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: ssv2-16x2x3 4 | eval_name: video_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | dataset_train: /your_path_to_ssv2_train_csv_file_index.csv 8 | dataset_val: /your_path_to_ssv2_val_csv_file_index.csv 9 | dataset_type: VideoDataset 10 | num_classes: 174 11 | frames_per_clip: 16 12 | num_segments: 2 13 | num_views_per_segment: 3 14 | frame_step: 4 15 | optimization: 16 | attend_across_segments: true 17 | num_epochs: 20 18 | resolution: 224 19 | batch_size: 4 20 | weight_decay: 0.01 21 | lr: 0.001 22 | start_lr: 0.001 23 | final_lr: 0.0 24 | warmup: 0. 
25 | use_bfloat16: true 26 | pretrain: 27 | model_name: vit_large 28 | checkpoint_key: target_encoder 29 | clip_duration: null 30 | frames_per_clip: 16 31 | tubelet_size: 2 32 | uniform_power: true 33 | use_silu: false 34 | tight_silu: false 35 | use_sdpa: true 36 | patch_size: 16 37 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 38 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 39 | write_tag: jepa 40 | -------------------------------------------------------------------------------- /configs/evals/vith16_384_ssv2_16x2x3.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: ssv2-16x2x3 4 | eval_name: video_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | dataset_train: /your_path_to_ssv2_train_csv_file_index.csv 8 | dataset_val: /your_path_to_ssv2_val_csv_file_index.csv 9 | dataset_type: VideoDataset 10 | num_classes: 174 11 | frames_per_clip: 16 12 | num_segments: 2 13 | num_views_per_segment: 3 14 | frame_step: 4 15 | optimization: 16 | attend_across_segments: true 17 | num_epochs: 20 18 | resolution: 384 19 | batch_size: 4 20 | weight_decay: 0.01 21 | lr: 0.001 22 | start_lr: 0.001 23 | final_lr: 0.0 24 | warmup: 0. 25 | use_bfloat16: true 26 | pretrain: 27 | model_name: vit_huge 28 | checkpoint_key: target_encoder 29 | clip_duration: null 30 | frames_per_clip: 16 31 | tubelet_size: 2 32 | uniform_power: true 33 | use_silu: false 34 | tight_silu: false 35 | use_sdpa: true 36 | patch_size: 16 37 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 38 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 39 | write_tag: jepa 40 | -------------------------------------------------------------------------------- /configs/evals/vith16_k400_16x8x3.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: k400-16x8x3 4 | eval_name: video_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv 8 | dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv 9 | dataset_type: VideoDataset 10 | num_classes: 400 11 | frames_per_clip: 16 12 | num_segments: 8 13 | num_views_per_segment: 3 14 | frame_step: 4 15 | optimization: 16 | attend_across_segments: true 17 | num_epochs: 20 18 | resolution: 224 19 | batch_size: 4 20 | weight_decay: 0.01 21 | lr: 0.001 22 | start_lr: 0.001 23 | final_lr: 0.0 24 | warmup: 0. 
25 | use_bfloat16: true 26 | pretrain: 27 | model_name: vit_huge 28 | checkpoint_key: target_encoder 29 | clip_duration: null 30 | frames_per_clip: 16 31 | tubelet_size: 2 32 | uniform_power: true 33 | use_silu: false 34 | tight_silu: false 35 | use_sdpa: true 36 | patch_size: 16 37 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 38 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 39 | write_tag: jepa 40 | -------------------------------------------------------------------------------- /configs/evals/vitl16_k400_16x8x3.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: k400-16x8x3 4 | eval_name: video_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv 8 | dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv 9 | dataset_type: VideoDataset 10 | num_classes: 400 11 | frames_per_clip: 16 12 | num_segments: 8 13 | num_views_per_segment: 3 14 | frame_step: 4 15 | optimization: 16 | attend_across_segments: true 17 | num_epochs: 20 18 | resolution: 224 19 | batch_size: 4 20 | weight_decay: 0.01 21 | lr: 0.001 22 | start_lr: 0.001 23 | final_lr: 0.0 24 | warmup: 0. 25 | use_bfloat16: true 26 | pretrain: 27 | model_name: vit_large 28 | checkpoint_key: target_encoder 29 | clip_duration: null 30 | frames_per_clip: 16 31 | tubelet_size: 2 32 | uniform_power: true 33 | use_silu: false 34 | tight_silu: false 35 | use_sdpa: true 36 | patch_size: 16 37 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 38 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 39 | write_tag: jepa 40 | -------------------------------------------------------------------------------- /configs/evals/vith16_384_k400_16x8x3.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | tag: k400-16x8x3 4 | eval_name: video_classification_frozen 5 | resume_checkpoint: false 6 | data: 7 | dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv 8 | dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv 9 | dataset_type: VideoDataset 10 | num_classes: 400 11 | frames_per_clip: 16 12 | num_segments: 8 13 | num_views_per_segment: 3 14 | frame_step: 4 15 | optimization: 16 | attend_across_segments: true 17 | num_epochs: 20 18 | resolution: 384 19 | batch_size: 4 20 | weight_decay: 0.01 21 | lr: 0.001 22 | start_lr: 0.001 23 | final_lr: 0.0 24 | warmup: 0. 25 | use_bfloat16: true 26 | pretrain: 27 | model_name: vit_huge 28 | checkpoint_key: target_encoder 29 | clip_duration: null 30 | frames_per_clip: 16 31 | tubelet_size: 2 32 | uniform_power: true 33 | use_silu: false 34 | tight_silu: false 35 | use_sdpa: true 36 | patch_size: 16 37 | folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ 38 | checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder 39 | write_tag: jepa 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to JEPA 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. 
If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 4 spaces for indentation rather than tabs 31 | * 80 character line length 32 | * PEP8 formatting 33 | 34 | ## License 35 | By contributing to this repository, you agree that your contributions will be licensed 36 | under the LICENSE file in the root directory of this source tree. 37 | -------------------------------------------------------------------------------- /src/models/utils/multimask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class MultiMaskWrapper(nn.Module): 12 | 13 | def __init__(self, backbone): 14 | super().__init__() 15 | self.backbone = backbone 16 | 17 | def forward(self, x, masks=None): 18 | if masks is None: 19 | return self.backbone(x) 20 | 21 | if (masks is not None) and not isinstance(masks, list): 22 | masks = [masks] 23 | outs = [] 24 | for m in masks: 25 | outs += [self.backbone(x, masks=m)] 26 | return outs 27 | 28 | 29 | class PredictorMultiMaskWrapper(nn.Module): 30 | 31 | def __init__(self, backbone): 32 | super().__init__() 33 | self.backbone = backbone 34 | 35 | def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): 36 | if type(ctxt) is not list: 37 | ctxt = [ctxt] 38 | if type(tgt) is not list: 39 | tgt = [tgt] 40 | if type(masks_ctxt) is not list: 41 | masks_ctxt = [masks_ctxt] 42 | if type(masks_tgt) is not list: 43 | masks_tgt = [masks_tgt] 44 | 45 | outs = [] 46 | for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): 47 | outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] 48 | return outs 49 | -------------------------------------------------------------------------------- /src/models/utils/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
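A small, self-contained sketch of how MultiMaskWrapper (src/models/utils/multimask.py above) fans a list of masks out over a backbone, using apply_masks from src/masks/utils.py shown earlier. ToyBackbone is a stand-in for the repository's vision transformer, and all shapes and indices are illustrative.

import torch
import torch.nn as nn

from src.masks.utils import apply_masks
from src.models.utils.multimask import MultiMaskWrapper


class ToyBackbone(nn.Module):
    # Stand-in encoder: embeds patch tokens, then keeps only the patch
    # indices selected by `masks` (the same contract apply_masks implements).
    def __init__(self, dim=8):
        super().__init__()
        self.embed = nn.Linear(dim, dim)

    def forward(self, x, masks=None):
        x = self.embed(x)                 # [B, N, D]
        if masks is not None:
            x = apply_masks(x, [masks])   # keep K of the N patch tokens
        return x


B, N, D = 2, 16, 8
x = torch.randn(B, N, D)
mask_a = torch.arange(0, 4).unsqueeze(0).repeat(B, 1)    # keep patches 0..3
mask_b = torch.arange(4, 10).unsqueeze(0).repeat(B, 1)   # keep patches 4..9

encoder = MultiMaskWrapper(ToyBackbone(dim=D))
outs = encoder(x, masks=[mask_a, mask_b])
print([o.shape for o in outs])           # [2, 4, 8] and [2, 6, 8]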
6 | # 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class PatchEmbed(nn.Module): 12 | """ 13 | Image to Patch Embedding 14 | """ 15 | def __init__( 16 | self, 17 | patch_size=16, 18 | in_chans=3, 19 | embed_dim=768 20 | ): 21 | super().__init__() 22 | self.patch_size = patch_size 23 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 24 | 25 | def forward(self, x): 26 | B, C, H, W = x.shape 27 | x = self.proj(x).flatten(2).transpose(1, 2) 28 | return x 29 | 30 | 31 | class PatchEmbed3D(nn.Module): 32 | """ 33 | Image to Patch Embedding 34 | """ 35 | 36 | def __init__( 37 | self, 38 | patch_size=16, 39 | tubelet_size=2, 40 | in_chans=3, 41 | embed_dim=768, 42 | ): 43 | super().__init__() 44 | self.patch_size = patch_size 45 | self.tubelet_size = tubelet_size 46 | 47 | self.proj = nn.Conv3d( 48 | in_channels=in_chans, 49 | out_channels=embed_dim, 50 | kernel_size=(tubelet_size, patch_size, patch_size), 51 | stride=(tubelet_size, patch_size, patch_size), 52 | ) 53 | 54 | def forward(self, x, **kwargs): 55 | B, C, T, H, W = x.shape 56 | x = self.proj(x).flatten(2).transpose(1, 2) 57 | return x 58 | -------------------------------------------------------------------------------- /configs/pretrain/vith16.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | nodes: 16 3 | tasks_per_node: 8 4 | data: 5 | dataset_type: VideoDataset 6 | datasets: 7 | - /your_path_to_kinetics710_csv_file_index.csv 8 | - /your_path_to_ssv2_csv_file_index.csv 9 | - /your_path_to_howto100m_csv_file_index.csv 10 | decode_one_clip: true 11 | batch_size: 24 12 | num_clips: 1 13 | num_frames: 16 14 | tubelet_size: 2 15 | sampling_rate: 4 16 | crop_size: 224 17 | patch_size: 16 18 | pin_mem: true 19 | num_workers: 12 20 | filter_short_videos: false 21 | clip_duration: null 22 | data_aug: 23 | auto_augment: false 24 | motion_shift: false 25 | random_resize_aspect_ratio: 26 | - 0.75 27 | - 1.35 28 | random_resize_scale: 29 | - 0.3 30 | - 1.0 31 | reprob: 0.0 32 | logging: 33 | folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ 34 | write_tag: jepa 35 | loss: 36 | loss_exp: 1.0 37 | reg_coeff: 0.0 38 | mask: 39 | - aspect_ratio: 40 | - 0.75 41 | - 1.5 42 | num_blocks: 8 43 | spatial_scale: 44 | - 0.15 45 | - 0.15 46 | temporal_scale: 47 | - 1.0 48 | - 1.0 49 | max_temporal_keep: 1.0 50 | max_keep: null 51 | - aspect_ratio: 52 | - 0.75 53 | - 1.5 54 | num_blocks: 2 55 | spatial_scale: 56 | - 0.7 57 | - 0.7 58 | temporal_scale: 59 | - 1.0 60 | - 1.0 61 | max_temporal_keep: 1.0 62 | max_keep: null 63 | meta: 64 | load_checkpoint: false 65 | read_checkpoint: null 66 | seed: 234 67 | eval_freq: 100 68 | use_sdpa: true 69 | dtype: bfloat16 70 | model: 71 | model_name: vit_huge 72 | pred_depth: 12 73 | pred_embed_dim: 384 74 | uniform_power: true 75 | use_mask_tokens: true 76 | zero_init_mask_tokens: true 77 | optimization: 78 | ipe: 300 79 | ipe_scale: 1.25 80 | clip_grad: 10.0 81 | weight_decay: 0.04 82 | final_weight_decay: 0.4 83 | epochs: 300 84 | warmup: 40 85 | start_lr: 0.0002 86 | lr: 0.000625 87 | final_lr: 1.0e-06 88 | ema: 89 | - 0.998 90 | - 1.0 91 | -------------------------------------------------------------------------------- /configs/pretrain/vitl16.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | nodes: 16 3 | tasks_per_node: 8 4 | data: 5 | dataset_type: VideoDataset 6 | datasets: 7 | - /your_path_to_kinetics710_csv_file_index.csv 8 | - 
/your_path_to_ssv2_csv_file_index.csv 9 | - /your_path_to_howto100m_csv_file_index.csv 10 | decode_one_clip: true 11 | batch_size: 24 12 | num_clips: 1 13 | num_frames: 16 14 | tubelet_size: 2 15 | sampling_rate: 4 16 | crop_size: 224 17 | patch_size: 16 18 | pin_mem: true 19 | num_workers: 12 20 | filter_short_videos: false 21 | clip_duration: null 22 | data_aug: 23 | auto_augment: false 24 | motion_shift: false 25 | random_resize_aspect_ratio: 26 | - 0.75 27 | - 1.35 28 | random_resize_scale: 29 | - 0.3 30 | - 1.0 31 | reprob: 0.0 32 | logging: 33 | folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ 34 | write_tag: jepa 35 | loss: 36 | loss_exp: 1.0 37 | reg_coeff: 0.0 38 | mask: 39 | - aspect_ratio: 40 | - 0.75 41 | - 1.5 42 | num_blocks: 8 43 | spatial_scale: 44 | - 0.15 45 | - 0.15 46 | temporal_scale: 47 | - 1.0 48 | - 1.0 49 | max_temporal_keep: 1.0 50 | max_keep: null 51 | - aspect_ratio: 52 | - 0.75 53 | - 1.5 54 | num_blocks: 2 55 | spatial_scale: 56 | - 0.7 57 | - 0.7 58 | temporal_scale: 59 | - 1.0 60 | - 1.0 61 | max_temporal_keep: 1.0 62 | max_keep: null 63 | meta: 64 | load_checkpoint: false 65 | read_checkpoint: null 66 | seed: 234 67 | eval_freq: 100 68 | use_sdpa: true 69 | dtype: bfloat16 70 | model: 71 | model_name: vit_large 72 | pred_depth: 12 73 | pred_embed_dim: 384 74 | uniform_power: true 75 | use_mask_tokens: true 76 | zero_init_mask_tokens: true 77 | optimization: 78 | ipe: 300 79 | ipe_scale: 1.25 80 | clip_grad: 10.0 81 | weight_decay: 0.04 82 | final_weight_decay: 0.4 83 | epochs: 300 84 | warmup: 40 85 | start_lr: 0.0002 86 | lr: 0.000625 87 | final_lr: 1.0e-06 88 | ema: 89 | - 0.998 90 | - 1.0 91 | -------------------------------------------------------------------------------- /configs/pretrain/vith16_384.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | nodes: 30 3 | tasks_per_node: 8 4 | data: 5 | dataset_type: VideoDataset 6 | datasets: 7 | - /your_path_to_kinetics710_csv_file_index.csv 8 | - /your_path_to_ssv2_csv_file_index.csv 9 | - /your_path_to_howto100m_csv_file_index.csv 10 | decode_one_clip: true 11 | batch_size: 10 12 | num_clips: 1 13 | num_frames: 16 14 | tubelet_size: 2 15 | sampling_rate: 4 16 | crop_size: 384 17 | patch_size: 16 18 | pin_mem: true 19 | num_workers: 12 20 | filter_short_videos: false 21 | clip_duration: null 22 | data_aug: 23 | auto_augment: false 24 | motion_shift: false 25 | random_resize_aspect_ratio: 26 | - 0.75 27 | - 1.35 28 | random_resize_scale: 29 | - 0.3 30 | - 1.0 31 | reprob: 0.0 32 | logging: 33 | folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ 34 | write_tag: jepa 35 | loss: 36 | loss_exp: 1.0 37 | reg_coeff: 0.0 38 | mask: 39 | - aspect_ratio: 40 | - 0.75 41 | - 1.5 42 | num_blocks: 8 43 | spatial_scale: 44 | - 0.15 45 | - 0.15 46 | temporal_scale: 47 | - 1.0 48 | - 1.0 49 | max_temporal_keep: 1.0 50 | max_keep: null 51 | - aspect_ratio: 52 | - 0.75 53 | - 1.5 54 | num_blocks: 2 55 | spatial_scale: 56 | - 0.7 57 | - 0.7 58 | temporal_scale: 59 | - 1.0 60 | - 1.0 61 | max_temporal_keep: 1.0 62 | max_keep: null 63 | meta: 64 | load_checkpoint: false 65 | read_checkpoint: null 66 | seed: 234 67 | eval_freq: 100 68 | use_sdpa: true 69 | dtype: bfloat16 70 | model: 71 | model_name: vit_huge 72 | pred_depth: 12 73 | pred_embed_dim: 384 74 | uniform_power: true 75 | use_mask_tokens: true 76 | zero_init_mask_tokens: true 77 | optimization: 78 | ipe: 300 79 | ipe_scale: 1.25 80 | clip_grad: 10.0 81 | weight_decay: 0.04 82 | 
final_weight_decay: 0.4 83 | epochs: 300 84 | warmup: 40 85 | start_lr: 0.0002 86 | lr: 0.000625 87 | final_lr: 1.0e-06 88 | ema: 89 | - 0.998 90 | - 1.0 91 | -------------------------------------------------------------------------------- /evals/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | 10 | import multiprocessing as mp 11 | 12 | import pprint 13 | import yaml 14 | 15 | from src.utils.distributed import init_distributed 16 | 17 | from evals.scaffold import main as eval_main 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | '--fname', type=str, 22 | help='name of config file to load', 23 | default='configs.yaml') 24 | parser.add_argument( 25 | '--devices', type=str, nargs='+', default=['cuda:0'], 26 | help='which devices to use on local machine') 27 | 28 | 29 | def process_main(rank, fname, world_size, devices): 30 | import os 31 | os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) 32 | 33 | import logging 34 | logging.basicConfig() 35 | logger = logging.getLogger() 36 | if rank == 0: 37 | logger.setLevel(logging.INFO) 38 | else: 39 | logger.setLevel(logging.ERROR) 40 | 41 | logger.info(f'called-params {fname}') 42 | 43 | # Load config 44 | params = None 45 | with open(fname, 'r') as y_file: 46 | params = yaml.load(y_file, Loader=yaml.FullLoader) 47 | logger.info('loaded params...') 48 | pp = pprint.PrettyPrinter(indent=4) 49 | pp.pprint(params) 50 | 51 | # Init distributed (access to comm between GPUS on same machine) 52 | world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) 53 | logger.info(f'Running... (rank: {rank}/{world_size})') 54 | 55 | # Launch the eval with loaded config 56 | eval_main(params['eval_name'], args_eval=params) 57 | 58 | 59 | if __name__ == '__main__': 60 | args = parser.parse_args() 61 | num_gpus = len(args.devices) 62 | mp.set_start_method('spawn') 63 | for rank in range(num_gpus): 64 | mp.Process( 65 | target=process_main, 66 | args=(rank, args.fname, num_gpus, args.devices) 67 | ).start() 68 | -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
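evals/main.py above is normally launched from the command line (e.g. `python -m evals.main --fname configs/evals/vitl16_in1k.yaml --devices cuda:0`), and app/main.py later in this listing follows the same pattern for pretraining. Below is a single-process sketch of what each spawned worker does; it assumes the dataset and checkpoint paths in the chosen YAML have been filled in.

import pprint
import yaml

from src.utils.distributed import init_distributed
from evals.scaffold import main as eval_main

# Load one of the eval configs shown above; the path is illustrative.
with open('configs/evals/vitl16_in1k.yaml', 'r') as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)
pprint.PrettyPrinter(indent=4).pprint(params)

# Trivial single-process "distributed" setup, then hand off to the scaffold,
# which imports evals.image_classification_frozen.eval and calls its main().
world_size, rank = init_distributed(rank_and_world_size=(0, 1))
eval_main(params['eval_name'], args_eval=params)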
6 | # 7 | 8 | import argparse 9 | 10 | import multiprocessing as mp 11 | 12 | import pprint 13 | import yaml 14 | 15 | from app.scaffold import main as app_main 16 | from src.utils.distributed import init_distributed 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | '--fname', type=str, 21 | help='name of config file to load', 22 | default='configs.yaml') 23 | parser.add_argument( 24 | '--devices', type=str, nargs='+', default=['cuda:0'], 25 | help='which devices to use on local machine') 26 | 27 | 28 | def process_main(rank, fname, world_size, devices): 29 | import os 30 | os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) 31 | 32 | import logging 33 | from src.utils.logging import get_logger 34 | logger = get_logger(force=True) 35 | if rank == 0: 36 | logger.setLevel(logging.INFO) 37 | else: 38 | logger.setLevel(logging.ERROR) 39 | 40 | logger.info(f'called-params {fname}') 41 | 42 | # Load config 43 | params = None 44 | with open(fname, 'r') as y_file: 45 | params = yaml.load(y_file, Loader=yaml.FullLoader) 46 | logger.info('loaded params...') 47 | 48 | # Log config 49 | if rank == 0: 50 | pprint.PrettyPrinter(indent=4).pprint(params) 51 | dump = os.path.join(params['logging']['folder'], 'params-pretrain.yaml') 52 | with open(dump, 'w') as f: 53 | yaml.dump(params, f) 54 | 55 | # Init distributed (access to comm between GPUS on same machine) 56 | world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) 57 | logger.info(f'Running... (rank: {rank}/{world_size})') 58 | 59 | # Launch the app with loaded config 60 | app_main(params['app'], args=params) 61 | 62 | 63 | if __name__ == '__main__': 64 | args = parser.parse_args() 65 | num_gpus = len(args.devices) 66 | mp.set_start_method('spawn') 67 | for rank in range(num_gpus): 68 | mp.Process( 69 | target=process_main, 70 | args=(rank, args.fname, num_gpus, args.devices) 71 | ).start() 72 | -------------------------------------------------------------------------------- /src/datasets/image_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import os 9 | 10 | from logging import getLogger 11 | 12 | import torch 13 | import torchvision 14 | 15 | _GLOBAL_SEED = 0 16 | logger = getLogger() 17 | 18 | 19 | class ImageFolder(torchvision.datasets.ImageFolder): 20 | 21 | def __init__( 22 | self, 23 | root, 24 | image_folder='imagenet_full_size/061417/', 25 | transform=None, 26 | train=True, 27 | ): 28 | """ 29 | ImageFolder 30 | :param root: root network directory for ImageFolder data 31 | :param image_folder: path to images inside root network directory 32 | :param train: whether to load train data (or validation) 33 | """ 34 | 35 | suffix = 'train/' if train else 'val/' 36 | data_path = os.path.join(root, image_folder, suffix) 37 | logger.info(f'data-path {data_path}') 38 | super(ImageFolder, self).__init__(root=data_path, transform=transform) 39 | logger.info('Initialized ImageFolder') 40 | 41 | 42 | def make_imagedataset( 43 | transform, 44 | batch_size, 45 | collator=None, 46 | pin_mem=True, 47 | num_workers=8, 48 | world_size=1, 49 | rank=0, 50 | root_path=None, 51 | image_folder=None, 52 | training=True, 53 | copy_data=False, 54 | drop_last=True, 55 | persistent_workers=False, 56 | subset_file=None 57 | ): 58 | dataset = ImageFolder( 59 | root=root_path, 60 | image_folder=image_folder, 61 | transform=transform, 62 | train=training) 63 | logger.info('ImageFolder dataset created') 64 | dist_sampler = torch.utils.data.distributed.DistributedSampler( 65 | dataset=dataset, 66 | num_replicas=world_size, 67 | rank=rank) 68 | data_loader = torch.utils.data.DataLoader( 69 | dataset, 70 | collate_fn=collator, 71 | sampler=dist_sampler, 72 | batch_size=batch_size, 73 | drop_last=drop_last, 74 | pin_memory=pin_mem, 75 | num_workers=num_workers, 76 | persistent_workers=persistent_workers) 77 | logger.info('ImageFolder unsupervised data loader created') 78 | 79 | return dataset, data_loader, dist_sampler 80 | -------------------------------------------------------------------------------- /src/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import math 9 | 10 | 11 | class WarmupCosineSchedule(object): 12 | 13 | def __init__( 14 | self, 15 | optimizer, 16 | warmup_steps, 17 | start_lr, 18 | ref_lr, 19 | T_max, 20 | last_epoch=-1, 21 | final_lr=0. 22 | ): 23 | self.optimizer = optimizer 24 | self.start_lr = start_lr 25 | self.ref_lr = ref_lr 26 | self.final_lr = final_lr 27 | self.warmup_steps = warmup_steps 28 | self.T_max = T_max - warmup_steps 29 | self._step = 0. 30 | 31 | def step(self): 32 | self._step += 1 33 | if self._step < self.warmup_steps: 34 | progress = float(self._step) / float(max(1, self.warmup_steps)) 35 | new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) 36 | else: 37 | # -- progress after warmup 38 | progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) 39 | new_lr = max(self.final_lr, 40 | self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) 41 | 42 | for group in self.optimizer.param_groups: 43 | group['lr'] = new_lr 44 | 45 | return new_lr 46 | 47 | 48 | class CosineWDSchedule(object): 49 | 50 | def __init__( 51 | self, 52 | optimizer, 53 | ref_wd, 54 | T_max, 55 | final_wd=0. 
56 | ): 57 | self.optimizer = optimizer 58 | self.ref_wd = ref_wd 59 | self.final_wd = final_wd 60 | self.T_max = T_max 61 | self._step = 0. 62 | 63 | def step(self): 64 | self._step += 1 65 | progress = self._step / self.T_max 66 | new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) 67 | 68 | if self.final_wd <= self.ref_wd: 69 | new_wd = max(self.final_wd, new_wd) 70 | else: 71 | new_wd = min(self.final_wd, new_wd) 72 | 73 | for group in self.optimizer.param_groups: 74 | if ('WD_exclude' not in group) or not group['WD_exclude']: 75 | group['weight_decay'] = new_wd 76 | return new_wd 77 | -------------------------------------------------------------------------------- /src/utils/tensors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import math 9 | 10 | import torch 11 | 12 | from logging import getLogger 13 | 14 | logger = getLogger() 15 | 16 | 17 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 18 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 19 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 20 | def norm_cdf(x): 21 | # Computes standard normal cumulative distribution function 22 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 23 | 24 | with torch.no_grad(): 25 | # Values are generated by using a truncated uniform distribution and 26 | # then using the inverse CDF for the normal distribution. 27 | # Get upper and lower cdf values 28 | l = norm_cdf((a - mean) / std) 29 | u = norm_cdf((b - mean) / std) 30 | 31 | # Uniformly fill tensor with values from [l, u], then translate to 32 | # [2l-1, 2u-1]. 33 | tensor.uniform_(2 * l - 1, 2 * u - 1) 34 | 35 | # Use inverse cdf transform for normal distribution to get truncated 36 | # standard normal 37 | tensor.erfinv_() 38 | 39 | # Transform to proper mean, std 40 | tensor.mul_(std * math.sqrt(2.)) 41 | tensor.add_(mean) 42 | 43 | # Clamp to ensure it's in the proper range 44 | tensor.clamp_(min=a, max=b) 45 | return tensor 46 | 47 | 48 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 49 | # type: (Tensor, float, float, float, float) -> Tensor 50 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 51 | 52 | 53 | def apply_masks(x, masks): 54 | """ 55 | :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] 56 | :param masks: list of tensors containing indices of patches [0,N) to keep 57 | """ 58 | all_x = [] 59 | for m in masks: 60 | mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) 61 | all_x += [torch.gather(x, dim=1, index=mask_keep)] 62 | return torch.cat(all_x, dim=0) 63 | 64 | 65 | def repeat_interleave_batch(x, B, repeat): 66 | N = len(x) // B 67 | x = torch.cat([ 68 | torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) 69 | for i in range(N) 70 | ], dim=0) 71 | return x 72 | -------------------------------------------------------------------------------- /src/datasets/data_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
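A usage sketch for WarmupCosineSchedule and CosineWDSchedule from src/utils/schedulers.py above. The numbers mirror configs/pretrain/vitl16.yaml earlier in this listing; wiring T_max as ipe * epochs * ipe_scale is an assumption about how the (unlisted) training loop uses these classes, so treat the whole snippet as illustrative rather than the repository's exact setup.

import torch

from src.utils.schedulers import WarmupCosineSchedule, CosineWDSchedule

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.000625, weight_decay=0.04)

ipe, epochs, warmup, ipe_scale = 300, 300, 40, 1.25   # values from the pretrain configs
total_steps = int(ipe * epochs * ipe_scale)

lr_scheduler = WarmupCosineSchedule(
    optimizer,
    warmup_steps=int(warmup * ipe),   # linear warmup from start_lr up to ref_lr
    start_lr=0.0002,
    ref_lr=0.000625,
    final_lr=1.0e-06,
    T_max=total_steps)
wd_scheduler = CosineWDSchedule(
    optimizer,
    ref_wd=0.04,                      # weight_decay annealed toward final_weight_decay
    final_wd=0.4,
    T_max=total_steps)

for step in range(10):                # one scheduler step per optimizer update
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    new_lr = lr_scheduler.step()      # returns the lr applied to all param groups
    new_wd = wd_scheduler.step()      # skips groups flagged with WD_exclude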
6 | # 7 | 8 | from logging import getLogger 9 | 10 | 11 | _GLOBAL_SEED = 0 12 | logger = getLogger() 13 | 14 | 15 | def init_data( 16 | batch_size, 17 | transform=None, 18 | shared_transform=None, 19 | data='ImageNet', 20 | collator=None, 21 | pin_mem=True, 22 | num_workers=8, 23 | world_size=1, 24 | rank=0, 25 | root_path=None, 26 | image_folder=None, 27 | training=True, 28 | copy_data=False, 29 | drop_last=True, 30 | tokenize_txt=True, 31 | subset_file=None, 32 | clip_len=8, 33 | frame_sample_rate=2, 34 | duration=None, 35 | num_clips=1, 36 | random_clip_sampling=True, 37 | allow_clip_overlap=False, 38 | filter_short_videos=False, 39 | filter_long_videos=int(1e9), 40 | decode_one_clip=True, 41 | datasets_weights=None, 42 | persistent_workers=False, 43 | repeat_wds=False, 44 | ipe=300, 45 | log_dir=None, 46 | ): 47 | 48 | if (data.lower() == 'imagenet') \ 49 | or (data.lower() == 'inat21') \ 50 | or (data.lower() == 'places205'): 51 | from src.datasets.image_dataset import make_imagedataset 52 | dataset, data_loader, dist_sampler = make_imagedataset( 53 | transform=transform, 54 | batch_size=batch_size, 55 | collator=collator, 56 | pin_mem=pin_mem, 57 | training=training, 58 | num_workers=num_workers, 59 | world_size=world_size, 60 | rank=rank, 61 | root_path=root_path, 62 | image_folder=image_folder, 63 | persistent_workers=persistent_workers, 64 | copy_data=copy_data, 65 | drop_last=drop_last, 66 | subset_file=subset_file) 67 | 68 | elif data.lower() == 'videodataset': 69 | from src.datasets.video_dataset import make_videodataset 70 | dataset, data_loader, dist_sampler = make_videodataset( 71 | data_paths=root_path, 72 | batch_size=batch_size, 73 | frames_per_clip=clip_len, 74 | frame_step=frame_sample_rate, 75 | duration=duration, 76 | num_clips=num_clips, 77 | random_clip_sampling=random_clip_sampling, 78 | allow_clip_overlap=allow_clip_overlap, 79 | filter_short_videos=filter_short_videos, 80 | filter_long_videos=filter_long_videos, 81 | shared_transform=shared_transform, 82 | transform=transform, 83 | datasets_weights=datasets_weights, 84 | collator=collator, 85 | num_workers=num_workers, 86 | world_size=world_size, 87 | rank=rank, 88 | drop_last=drop_last, 89 | log_dir=log_dir) 90 | 91 | return (data_loader, dist_sampler) 92 | -------------------------------------------------------------------------------- /src/datasets/utils/weighted_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
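A sketch of calling init_data from src/datasets/data_manager.py above for the frozen image-classification evals. The placeholder paths follow the eval configs shown earlier, the torchvision transform is a plausible stand-in for whatever the eval scripts actually build, and reusing DefaultCollator here is an illustrative choice rather than a documented requirement.

import torchvision.transforms as T

from src.datasets.data_manager import init_data
from src.masks.default import DefaultCollator

transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])

# Returns (data_loader, dist_sampler); with world_size=1 and rank=0 this
# behaves like a plain single-process loader.
data_loader, dist_sampler = init_data(
    batch_size=16,
    transform=transform,
    data='ImageNet',
    collator=DefaultCollator(),
    num_workers=2,
    world_size=1,
    rank=0,
    root_path='/your_absolute_file_path_to_directory_where_image_datasets_are_stored/',
    image_folder='imagenet_full_size/061417/',
    training=True)

for (images, labels), _, _ in data_loader:
    print(images.shape, labels.shape)   # [16, 3, 224, 224] and [16]
    break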
6 | # 7 | 8 | from typing import Iterator, Optional 9 | from operator import itemgetter 10 | import numpy as np 11 | 12 | import torch 13 | from torch.utils.data import ( 14 | Dataset, 15 | Sampler, 16 | DistributedSampler, 17 | WeightedRandomSampler 18 | ) 19 | 20 | 21 | class DatasetFromSampler(Dataset): 22 | 23 | def __init__(self, sampler: Sampler): 24 | self.sampler = sampler 25 | self.sampler_list = None 26 | 27 | def __getitem__(self, index: int): 28 | if self.sampler_list is None: 29 | self.sampler_list = list(self.sampler) 30 | return self.sampler_list[index] 31 | 32 | def __len__(self) -> int: 33 | return len(self.sampler) 34 | 35 | 36 | class DistributedSamplerWrapper(DistributedSampler): 37 | """ Convert any Pytorch Sampler to a DistributedSampler """ 38 | 39 | def __init__( 40 | self, 41 | sampler, 42 | num_replicas: Optional[int] = None, 43 | rank: Optional[int] = None, 44 | shuffle: bool = True, 45 | ): 46 | super(DistributedSamplerWrapper, self).__init__( 47 | DatasetFromSampler(sampler), 48 | num_replicas=num_replicas, 49 | rank=rank, 50 | shuffle=shuffle, 51 | ) 52 | self.sampler = sampler 53 | 54 | def __iter__(self) -> Iterator[int]: 55 | self.dataset = DatasetFromSampler(self.sampler) 56 | indexes_of_indexes = super().__iter__() 57 | subsampler_indexes = self.dataset 58 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) 59 | 60 | 61 | class CustomWeightedRandomSampler(WeightedRandomSampler): 62 | """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ 63 | 64 | def __init__(self, *args, **kwargs): 65 | super().__init__(*args, **kwargs) 66 | 67 | def __iter__(self): 68 | rand_tensor = np.random.choice( 69 | range(0, len(self.weights)), 70 | size=self.num_samples, 71 | p=self.weights.numpy() / torch.sum(self.weights).numpy(), 72 | replace=self.replacement 73 | ) 74 | rand_tensor = torch.from_numpy(rand_tensor) 75 | return iter(rand_tensor.tolist()) 76 | 77 | 78 | class DistributedWeightedSampler(DistributedSamplerWrapper): 79 | 80 | def __init__( 81 | self, 82 | weights, 83 | num_replicas: Optional[int] = None, 84 | rank: Optional[int] = None, 85 | shuffle: bool = True, 86 | ): 87 | weighted_sampler = CustomWeightedRandomSampler( 88 | weights=weights, 89 | num_samples=len(weights), 90 | replacement=False) 91 | 92 | super(DistributedWeightedSampler, self).__init__( 93 | sampler=weighted_sampler, 94 | num_replicas=num_replicas, 95 | rank=rank, 96 | shuffle=shuffle, 97 | ) 98 | -------------------------------------------------------------------------------- /src/datasets/utils/video/functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
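A small example of DistributedWeightedSampler from src/datasets/utils/weighted_sampler.py above. The weights and the two-replica split are illustrative, and passing num_replicas and rank explicitly means no torch.distributed process group is needed to try it.

import torch

from src.datasets.utils.weighted_sampler import DistributedWeightedSampler

# Ten samples; the first five get twice the sampling weight of the last five.
weights = torch.cat([torch.full((5,), 2.0), torch.full((5,), 1.0)])

sampler = DistributedWeightedSampler(weights, num_replicas=2, rank=0, shuffle=True)

# Each rank iterates over its share of the weighted, replacement-free ordering;
# in a real run this sampler would be passed to DataLoader(..., sampler=sampler).
indices = list(sampler)
print(len(indices), indices)   # 5 indices drawn from 0..9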
6 | # 7 | 8 | import numbers 9 | import cv2 10 | import numpy as np 11 | import PIL 12 | import torch 13 | 14 | 15 | def _is_tensor_clip(clip): 16 | return torch.is_tensor(clip) and clip.ndimension() == 4 17 | 18 | 19 | def crop_clip(clip, min_h, min_w, h, w): 20 | if isinstance(clip[0], np.ndarray): 21 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 22 | 23 | elif isinstance(clip[0], PIL.Image.Image): 24 | cropped = [ 25 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 26 | ] 27 | else: 28 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 29 | 'but got list of {0}'.format(type(clip[0]))) 30 | return cropped 31 | 32 | 33 | def resize_clip(clip, size, interpolation='bilinear'): 34 | if isinstance(clip[0], np.ndarray): 35 | if isinstance(size, numbers.Number): 36 | im_h, im_w, im_c = clip[0].shape 37 | # Min spatial dim already matches minimal size 38 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 39 | and im_h == size): 40 | return clip 41 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 42 | size = (new_w, new_h) 43 | else: 44 | size = size[0], size[1] 45 | if interpolation == 'bilinear': 46 | np_inter = cv2.INTER_LINEAR 47 | else: 48 | np_inter = cv2.INTER_NEAREST 49 | scaled = [ 50 | cv2.resize(img, size, interpolation=np_inter) for img in clip 51 | ] 52 | elif isinstance(clip[0], PIL.Image.Image): 53 | if isinstance(size, numbers.Number): 54 | im_w, im_h = clip[0].size 55 | # Min spatial dim already matches minimal size 56 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 57 | and im_h == size): 58 | return clip 59 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 60 | size = (new_w, new_h) 61 | else: 62 | size = size[1], size[0] 63 | if interpolation == 'bilinear': 64 | pil_inter = PIL.Image.BILINEAR 65 | else: 66 | pil_inter = PIL.Image.NEAREST 67 | scaled = [img.resize(size, pil_inter) for img in clip] 68 | else: 69 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 70 | 'but got list of {0}'.format(type(clip[0]))) 71 | return scaled 72 | 73 | 74 | def get_resize_sizes(im_h, im_w, size): 75 | if im_w < im_h: 76 | ow = size 77 | oh = int(size * im_h / im_w) 78 | else: 79 | oh = size 80 | ow = int(size * im_w / im_h) 81 | return oh, ow 82 | 83 | 84 | def normalize(clip, mean, std, inplace=False): 85 | if not _is_tensor_clip(clip): 86 | raise TypeError('tensor is not a torch clip.') 87 | 88 | if not inplace: 89 | clip = clip.clone() 90 | 91 | dtype = clip.dtype 92 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 93 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 94 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 95 | 96 | return clip 97 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
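A quick sketch exercising resize_clip and crop_clip from src/datasets/utils/video/functional.py above on a fake clip (a list of HxWxC uint8 frames); the frame count and sizes are arbitrary.

import numpy as np

from src.datasets.utils.video import functional as F

# Fake 8-frame RGB clip at 240x320.
clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]

# Resize so the short side is 112 pixels (aspect ratio preserved), then take
# a centered 112x112 crop from every frame.
resized = F.resize_clip(clip, size=112, interpolation='bilinear')
h, w, _ = resized[0].shape
cropped = F.crop_clip(resized, min_h=(h - 112) // 2, min_w=(w - 112) // 2, h=112, w=112)
print(len(cropped), cropped[0].shape)   # 8 (112, 112, 3)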
6 | # 7 | 8 | import os 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | from logging import getLogger 14 | 15 | logger = getLogger() 16 | 17 | 18 | def init_distributed(port=37123, rank_and_world_size=(None, None)): 19 | 20 | if dist.is_available() and dist.is_initialized(): 21 | return dist.get_world_size(), dist.get_rank() 22 | 23 | rank, world_size = rank_and_world_size 24 | os.environ['MASTER_ADDR'] = 'localhost' 25 | 26 | if (rank is None) or (world_size is None): 27 | try: 28 | world_size = int(os.environ['SLURM_NTASKS']) 29 | rank = int(os.environ['SLURM_PROCID']) 30 | os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] 31 | except Exception: 32 | logger.info('SLURM vars not set (distributed training not available)') 33 | world_size, rank = 1, 0 34 | return world_size, rank 35 | 36 | try: 37 | os.environ['MASTER_PORT'] = str(port) 38 | torch.distributed.init_process_group( 39 | backend='nccl', 40 | world_size=world_size, 41 | rank=rank 42 | ) 43 | except Exception as e: 44 | world_size, rank = 1, 0 45 | logger.info(f'Rank: {rank}. Distributed training not available {e}') 46 | 47 | return world_size, rank 48 | 49 | 50 | class AllGather(torch.autograd.Function): 51 | 52 | @staticmethod 53 | def forward(ctx, x): 54 | if ( 55 | dist.is_available() 56 | and dist.is_initialized() 57 | and (dist.get_world_size() > 1) 58 | ): 59 | x = x.contiguous() 60 | outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 61 | dist.all_gather(outputs, x) 62 | return torch.cat(outputs, 0) 63 | return x 64 | 65 | @staticmethod 66 | def backward(ctx, grads): 67 | if ( 68 | dist.is_available() 69 | and dist.is_initialized() 70 | and (dist.get_world_size() > 1) 71 | ): 72 | s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() 73 | e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) 74 | grads = grads.contiguous() 75 | dist.all_reduce(grads) 76 | return grads[s:e] 77 | return grads 78 | 79 | 80 | class AllReduceSum(torch.autograd.Function): 81 | 82 | @staticmethod 83 | def forward(ctx, x): 84 | if ( 85 | dist.is_available() 86 | and dist.is_initialized() 87 | and (dist.get_world_size() > 1) 88 | ): 89 | x = x.contiguous() 90 | dist.all_reduce(x) 91 | return x 92 | 93 | @staticmethod 94 | def backward(ctx, grads): 95 | return grads 96 | 97 | 98 | class AllReduce(torch.autograd.Function): 99 | 100 | @staticmethod 101 | def forward(ctx, x): 102 | if ( 103 | dist.is_available() 104 | and dist.is_initialized() 105 | and (dist.get_world_size() > 1) 106 | ): 107 | x = x.contiguous() / dist.get_world_size() 108 | dist.all_reduce(x) 109 | return x 110 | 111 | @staticmethod 112 | def backward(ctx, grads): 113 | return grads 114 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /src/models/utils/pos_embs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | 10 | 11 | def get_3d_sincos_pos_embed( 12 | embed_dim, 13 | grid_size, 14 | grid_depth, 15 | cls_token=False, 16 | uniform_power=False 17 | ): 18 | """ 19 | grid_size: int of the grid height and width 20 | grid_depth: int of the grid depth 21 | returns: 22 | pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) 23 | or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) 24 | """ 25 | grid_d = np.arange(grid_depth, dtype=float) 26 | grid_h = np.arange(grid_size, dtype=float) 27 | grid_w = np.arange(grid_size, dtype=float) 28 | grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] 29 | 30 | if not uniform_power: 31 | h_embed_dim = embed_dim // 4 32 | w_embed_dim = embed_dim // 4 33 | d_embed_dim = embed_dim // 2 34 | else: 35 | h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) 36 | 37 | emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) 38 | emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) 39 | emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) 40 | pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) 41 | pos_embed = pos_embed[:, :embed_dim] 42 | if cls_token: 43 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 44 | return pos_embed 45 | 46 | 47 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 48 | """ 49 | grid_size: int of the grid height and width 50 | returns: 51 | pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) 52 | or [1+grid_size*grid_size, embed_dim] (w/ cls_token) 53 | """ 54 | grid_h = np.arange(grid_size, dtype=float) 55 | grid_w = np.arange(grid_size, dtype=float) 56 | grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] 57 | 58 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) 59 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) 60 | pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 61 | if cls_token: 62 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 63 | return pos_embed 64 | 65 | 66 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 67 | """ 68 | embed_dim: output dimension for each position 69 | grid_size: int of the grid length 70 | returns: 71 | pos_embed: [grid_size, embed_dim] (w/o cls_token) 72 | or [1+grid_size, embed_dim] (w/ cls_token) 73 | """ 74 | grid = np.arange(grid_size, dtype=float) 75 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) 76 | if cls_token: 77 | pos_embed = 
np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 78 | return pos_embed 79 | 80 | 81 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 82 | """ 83 | embed_dim: output dimension for each position 84 | pos: a list of positions to be encoded: size (M,) 85 | returns: (M, D) 86 | """ 87 | assert embed_dim % 2 == 0 88 | omega = np.arange(embed_dim // 2, dtype=float) 89 | omega /= embed_dim / 2. 90 | omega = 1. / 10000**omega # (D/2,) 91 | 92 | pos = pos.reshape(-1) # (M,) 93 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 94 | 95 | emb_sin = np.sin(out) # (M, D/2) 96 | emb_cos = np.cos(out) # (M, D/2) 97 | 98 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 99 | return emb 100 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import logging 9 | import sys 10 | 11 | import torch 12 | 13 | 14 | def gpu_timer(closure, log_timings=True): 15 | """ Helper to time gpu-time to execute closure() """ 16 | log_timings = log_timings and torch.cuda.is_available() 17 | 18 | elapsed_time = -1. 19 | if log_timings: 20 | start = torch.cuda.Event(enable_timing=True) 21 | end = torch.cuda.Event(enable_timing=True) 22 | start.record() 23 | 24 | result = closure() 25 | 26 | if log_timings: 27 | end.record() 28 | torch.cuda.synchronize() 29 | elapsed_time = start.elapsed_time(end) 30 | 31 | return result, elapsed_time 32 | 33 | 34 | LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(funcName)-25s] %(message)s" 35 | DATE_FORMAT = "%Y-%m-%d %H:%M:%S" 36 | 37 | 38 | def get_logger(name=None, force=False): 39 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 40 | format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) 41 | return logging.getLogger(name=name) 42 | 43 | 44 | class CSVLogger(object): 45 | 46 | def __init__(self, fname, *argv): 47 | self.fname = fname 48 | self.types = [] 49 | # -- print headers 50 | with open(self.fname, '+a') as f: 51 | for i, v in enumerate(argv, 1): 52 | self.types.append(v[0]) 53 | if i < len(argv): 54 | print(v[1], end=',', file=f) 55 | else: 56 | print(v[1], end='\n', file=f) 57 | 58 | def log(self, *argv): 59 | with open(self.fname, '+a') as f: 60 | for i, tv in enumerate(zip(self.types, argv), 1): 61 | end = ',' if i < len(argv) else '\n' 62 | print(tv[0] % tv[1], end=end, file=f) 63 | 64 | 65 | class AverageMeter(object): 66 | """computes and stores the average and current value""" 67 | 68 | def __init__(self): 69 | self.reset() 70 | 71 | def reset(self): 72 | self.val = 0 73 | self.avg = 0 74 | self.max = float('-inf') 75 | self.min = float('inf') 76 | self.sum = 0 77 | self.count = 0 78 | 79 | def update(self, val, n=1): 80 | self.val = val 81 | try: 82 | self.max = max(val, self.max) 83 | self.min = min(val, self.min) 84 | except Exception: 85 | pass 86 | self.sum += val * n 87 | self.count += n 88 | self.avg = self.sum / self.count 89 | 90 | 91 | def grad_logger(named_params): 92 | stats = AverageMeter() 93 | stats.first_layer = None 94 | stats.last_layer = None 95 | for n, p in named_params: 96 | if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): 97 | grad_norm = float(torch.norm(p.grad.data)) 98 | stats.update(grad_norm) 99 | if 'qkv' in 
n: 100 | stats.last_layer = grad_norm 101 | if stats.first_layer is None: 102 | stats.first_layer = grad_norm 103 | if stats.first_layer is None or stats.last_layer is None: 104 | stats.first_layer = stats.last_layer = 0. 105 | return stats 106 | 107 | 108 | def adamw_logger(optimizer): 109 | """ logging magnitude of first and second momentum buffers in adamw """ 110 | # TODO: assert that optimizer is instance of torch.optim.AdamW 111 | state = optimizer.state_dict().get('state') 112 | exp_avg_stats = AverageMeter() 113 | exp_avg_sq_stats = AverageMeter() 114 | for key in state: 115 | s = state.get(key) 116 | exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) 117 | exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) 118 | return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} 119 | -------------------------------------------------------------------------------- /src/masks/random_tube.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from multiprocessing import Value 9 | 10 | from logging import getLogger 11 | 12 | import torch 13 | import numpy as np 14 | 15 | _GLOBAL_SEED = 0 16 | logger = getLogger() 17 | 18 | 19 | class MaskCollator(object): 20 | 21 | def __init__( 22 | self, 23 | cfgs_mask, 24 | crop_size=(224, 224), 25 | num_frames=16, 26 | patch_size=(16, 16), 27 | tubelet_size=2, 28 | ): 29 | super(MaskCollator, self).__init__() 30 | 31 | self.mask_generators = [] 32 | for m in cfgs_mask: 33 | mask_generator = _MaskGenerator( 34 | crop_size=crop_size, 35 | num_frames=num_frames, 36 | spatial_patch_size=patch_size, 37 | temporal_patch_size=tubelet_size, 38 | ratio=m.get('ratio'), 39 | ) 40 | self.mask_generators.append(mask_generator) 41 | 42 | def step(self): 43 | for mask_generator in self.mask_generators: 44 | mask_generator.step() 45 | 46 | def __call__(self, batch): 47 | 48 | batch_size = len(batch) 49 | collated_batch = torch.utils.data.default_collate(batch) 50 | 51 | collated_masks_pred, collated_masks_enc = [], [] 52 | for i, mask_generator in enumerate(self.mask_generators): 53 | masks_enc, masks_pred = mask_generator(batch_size) 54 | collated_masks_enc.append(masks_enc) 55 | collated_masks_pred.append(masks_pred) 56 | 57 | return collated_batch, collated_masks_enc, collated_masks_pred 58 | 59 | 60 | class _MaskGenerator(object): 61 | 62 | def __init__( 63 | self, 64 | crop_size=(224, 224), 65 | num_frames=16, 66 | spatial_patch_size=(16, 16), 67 | temporal_patch_size=2, 68 | ratio=0.9, 69 | ): 70 | super(_MaskGenerator, self).__init__() 71 | if not isinstance(crop_size, tuple): 72 | crop_size = (crop_size, ) * 2 73 | self.crop_size = crop_size 74 | self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size 75 | self.duration = num_frames // temporal_patch_size 76 | 77 | self.spatial_patch_size = spatial_patch_size 78 | self.temporal_patch_size = temporal_patch_size 79 | self.num_patches_spatial = self.height*self.width 80 | 81 | self.ratio = ratio 82 | 83 | self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) 84 | self.num_keep = self.num_keep_spatial * self.duration 85 | 86 | self._itr_counter = Value('i', -1) # collator is shared across worker processes 87 | 88 | def step(self): 89 | i = self._itr_counter 90 | with 
i.get_lock(): 91 | i.value += 1 92 | v = i.value 93 | return v 94 | 95 | def __call__(self, batch_size): 96 | def sample_mask(): 97 | mask = np.hstack([ 98 | np.zeros(self.num_patches_spatial - self.num_keep_spatial), 99 | np.ones(self.num_keep_spatial), 100 | ]) 101 | np.random.shuffle(mask) 102 | mask = torch.tensor(np.tile(mask, (self.duration, 1))) 103 | mask = mask.flatten() 104 | mask_p = torch.argwhere(mask == 0).squeeze() 105 | mask_e = torch.nonzero(mask).squeeze() 106 | return mask_e, mask_p 107 | 108 | collated_masks_pred, collated_masks_enc = [], [] 109 | for _ in range(batch_size): 110 | mask_e, mask_p = sample_mask() 111 | collated_masks_enc.append(mask_e) 112 | collated_masks_pred.append(mask_p) 113 | 114 | collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) 115 | collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) 116 | 117 | return collated_masks_enc, collated_masks_pred 118 | -------------------------------------------------------------------------------- /src/models/attentive_pooler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from src.models.utils.modules import ( 14 | Block, 15 | CrossAttention, 16 | CrossAttentionBlock 17 | ) 18 | from src.utils.tensors import trunc_normal_ 19 | 20 | 21 | class AttentivePooler(nn.Module): 22 | """ Attentive Pooler """ 23 | def __init__( 24 | self, 25 | num_queries=1, 26 | embed_dim=768, 27 | num_heads=12, 28 | mlp_ratio=4.0, 29 | depth=1, 30 | norm_layer=nn.LayerNorm, 31 | init_std=0.02, 32 | qkv_bias=True, 33 | complete_block=True 34 | ): 35 | super().__init__() 36 | self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) 37 | 38 | self.complete_block = complete_block 39 | if complete_block: 40 | self.cross_attention_block = CrossAttentionBlock( 41 | dim=embed_dim, 42 | num_heads=num_heads, 43 | mlp_ratio=mlp_ratio, 44 | qkv_bias=qkv_bias, 45 | norm_layer=norm_layer) 46 | else: 47 | self.cross_attention_block = CrossAttention( 48 | dim=embed_dim, 49 | num_heads=num_heads, 50 | qkv_bias=qkv_bias) 51 | 52 | self.blocks = None 53 | if depth > 1: 54 | self.blocks = nn.ModuleList([ 55 | Block( 56 | dim=embed_dim, 57 | num_heads=num_heads, 58 | mlp_ratio=mlp_ratio, 59 | qkv_bias=qkv_bias, 60 | qk_scale=False, 61 | norm_layer=norm_layer) 62 | for i in range(depth-1)]) 63 | 64 | self.init_std = init_std 65 | trunc_normal_(self.query_tokens, std=self.init_std) 66 | self.apply(self._init_weights) 67 | self._rescale_blocks() 68 | 69 | def _rescale_blocks(self): 70 | def rescale(param, layer_id): 71 | param.div_(math.sqrt(2.0 * layer_id)) 72 | 73 | if self.complete_block: 74 | rescale(self.cross_attention_block.xattn.proj.weight.data, 1) 75 | rescale(self.cross_attention_block.mlp.fc2.weight.data, 1) 76 | else: 77 | rescale(self.cross_attention_block.proj.weight.data, 1) 78 | if self.blocks is not None: 79 | for layer_id, layer in enumerate(self.blocks, 1): 80 | rescale(layer.attn.proj.weight.data, layer_id + 1) 81 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 82 | 83 | def _init_weights(self, m): 84 | if isinstance(m, nn.Linear): 85 | trunc_normal_(m.weight, std=self.init_std) 86 | if isinstance(m, nn.Linear) and m.bias is not None: 87 | 
nn.init.constant_(m.bias, 0) 88 | elif isinstance(m, nn.LayerNorm): 89 | nn.init.constant_(m.bias, 0) 90 | nn.init.constant_(m.weight, 1.0) 91 | elif isinstance(m, nn.Conv2d): 92 | trunc_normal_(m.weight, std=self.init_std) 93 | if m.bias is not None: 94 | nn.init.constant_(m.bias, 0) 95 | 96 | def forward(self, x): 97 | q = self.query_tokens.repeat(len(x), 1, 1) 98 | q = self.cross_attention_block(q, x) 99 | if self.blocks is not None: 100 | for blk in self.blocks: 101 | q = blk(q) 102 | return q 103 | 104 | 105 | class AttentiveClassifier(nn.Module): 106 | """ Attentive Classifier """ 107 | def __init__( 108 | self, 109 | embed_dim=768, 110 | num_heads=12, 111 | mlp_ratio=4.0, 112 | depth=1, 113 | norm_layer=nn.LayerNorm, 114 | init_std=0.02, 115 | qkv_bias=True, 116 | num_classes=1000, 117 | complete_block=True, 118 | ): 119 | super().__init__() 120 | self.pooler = AttentivePooler( 121 | num_queries=1, 122 | embed_dim=embed_dim, 123 | num_heads=num_heads, 124 | mlp_ratio=mlp_ratio, 125 | depth=depth, 126 | norm_layer=norm_layer, 127 | init_std=init_std, 128 | qkv_bias=qkv_bias, 129 | complete_block=complete_block, 130 | ) 131 | self.linear = nn.Linear(embed_dim, num_classes, bias=True) 132 | 133 | def forward(self, x): 134 | x = self.pooler(x).squeeze(1) 135 | x = self.linear(x) 136 | return x 137 | -------------------------------------------------------------------------------- /src/datasets/utils/video/volume_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import torch 12 | 13 | 14 | def convert_img(img): 15 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" 16 | if len(img.shape) == 3: 17 | img = img.transpose(2, 0, 1) 18 | if len(img.shape) == 2: 19 | img = np.expand_dims(img, 0) 20 | return img 21 | 22 | 23 | class ClipToTensor(object): 24 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 25 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 26 | """ 27 | 28 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 29 | self.channel_nb = channel_nb 30 | self.div_255 = div_255 31 | self.numpy = numpy 32 | 33 | def __call__(self, clip): 34 | """ 35 | Args: clip (list of numpy.ndarray): clip (list of images) 36 | to be converted to tensor. 
37 | """ 38 | # Retrieve shape 39 | if isinstance(clip[0], np.ndarray): 40 | h, w, ch = clip[0].shape 41 | assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) 42 | elif isinstance(clip[0], Image.Image): 43 | w, h = clip[0].size 44 | else: 45 | raise TypeError( 46 | "Expected numpy.ndarray or PIL.Image\ 47 | but got list of {0}".format( 48 | type(clip[0]) 49 | ) 50 | ) 51 | 52 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 53 | 54 | # Convert 55 | for img_idx, img in enumerate(clip): 56 | if isinstance(img, np.ndarray): 57 | pass 58 | elif isinstance(img, Image.Image): 59 | img = np.array(img, copy=False) 60 | else: 61 | raise TypeError( 62 | "Expected numpy.ndarray or PIL.Image\ 63 | but got list of {0}".format( 64 | type(clip[0]) 65 | ) 66 | ) 67 | img = convert_img(img) 68 | np_clip[:, img_idx, :, :] = img 69 | if self.numpy: 70 | if self.div_255: 71 | np_clip = np_clip / 255.0 72 | return np_clip 73 | 74 | else: 75 | tensor_clip = torch.from_numpy(np_clip) 76 | 77 | if not isinstance(tensor_clip, torch.FloatTensor): 78 | tensor_clip = tensor_clip.float() 79 | if self.div_255: 80 | tensor_clip = torch.div(tensor_clip, 255) 81 | return tensor_clip 82 | 83 | 84 | # Note this norms data to -1/1 85 | class ClipToTensor_K(object): 86 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 87 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 88 | """ 89 | 90 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 91 | self.channel_nb = channel_nb 92 | self.div_255 = div_255 93 | self.numpy = numpy 94 | 95 | def __call__(self, clip): 96 | """ 97 | Args: clip (list of numpy.ndarray): clip (list of images) 98 | to be converted to tensor. 99 | """ 100 | # Retrieve shape 101 | if isinstance(clip[0], np.ndarray): 102 | h, w, ch = clip[0].shape 103 | assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) 104 | elif isinstance(clip[0], Image.Image): 105 | w, h = clip[0].size 106 | else: 107 | raise TypeError( 108 | "Expected numpy.ndarray or PIL.Image\ 109 | but got list of {0}".format( 110 | type(clip[0]) 111 | ) 112 | ) 113 | 114 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 115 | 116 | # Convert 117 | for img_idx, img in enumerate(clip): 118 | if isinstance(img, np.ndarray): 119 | pass 120 | elif isinstance(img, Image.Image): 121 | img = np.array(img, copy=False) 122 | else: 123 | raise TypeError( 124 | "Expected numpy.ndarray or PIL.Image\ 125 | but got list of {0}".format( 126 | type(clip[0]) 127 | ) 128 | ) 129 | img = convert_img(img) 130 | np_clip[:, img_idx, :, :] = img 131 | if self.numpy: 132 | if self.div_255: 133 | np_clip = (np_clip - 127.5) / 127.5 134 | return np_clip 135 | 136 | else: 137 | tensor_clip = torch.from_numpy(np_clip) 138 | 139 | if not isinstance(tensor_clip, torch.FloatTensor): 140 | tensor_clip = tensor_clip.float() 141 | if self.div_255: 142 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 143 | return tensor_clip 144 | 145 | 146 | class ToTensor(object): 147 | """Converts numpy array to tensor""" 148 | 149 | def __call__(self, array): 150 | tensor = torch.from_numpy(array) 151 | return tensor 152 | -------------------------------------------------------------------------------- /app/main_distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
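# --- Editorial note: an invocation sketch, not part of the original file. ---
# This launcher submits pre-training runs through submitit; the command line and
# config path below are illustrative assumptions, but the flags match the argparse
# options defined in this file, and each YAML config is expected to carry at least
# the 'app', 'nodes' and 'tasks_per_node' keys read by this launcher.
#
#   python -m app.main_distributed \
#     --fname path/to/pretrain_config.yaml \
#     --partition my_partition \
#     --folder /path/to/submitit/logs \
#     --time 4300
#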
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import argparse 9 | import os 10 | import pprint 11 | import yaml 12 | 13 | import submitit 14 | 15 | from app.scaffold import main as app_main 16 | from src.utils.logging import get_logger 17 | 18 | logger = get_logger(force=True) 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | '--folder', type=str, 24 | help='location to save submitit logs', 25 | default='/fsx-jepa/massran/submitit/') 26 | parser.add_argument( 27 | '--exclude', type=str, 28 | help='nodes to exclude from training', 29 | default=None) 30 | parser.add_argument( 31 | '--batch-launch', action='store_true', 32 | help='whether fname points to a file to batch-lauch several config files') 33 | parser.add_argument( 34 | '--fname', type=str, 35 | help='yaml file containing config file names to launch', 36 | default='configs.yaml') 37 | parser.add_argument( 38 | '--partition', type=str, 39 | help='cluster partition to submit jobs on') 40 | parser.add_argument( 41 | '--time', type=int, default=4300, 42 | help='time in minutes to run job') 43 | 44 | 45 | class Trainer: 46 | 47 | def __init__(self, args_pretrain, load_model=None): 48 | self.app = args_pretrain['app'] 49 | self.args_pretrain = args_pretrain 50 | self.load_model = load_model 51 | 52 | def __call__(self): 53 | app = self.app 54 | params = self.args_pretrain 55 | load_model = self.load_model 56 | 57 | logger.info('loaded pretrain params...') 58 | pp = pprint.PrettyPrinter(indent=4) 59 | pp.pprint(params) 60 | 61 | # Launch app with loaded config 62 | resume_preempt = False if load_model is None else load_model 63 | app_main(app, args=params, resume_preempt=resume_preempt) 64 | 65 | def checkpoint(self): 66 | fb_trainer = Trainer(self.args_pretrain, True) 67 | return submitit.helpers.DelayedSubmission(fb_trainer,) 68 | 69 | 70 | def launch_app_with_parsed_args( 71 | args_for_pretrain, 72 | submitit_folder, 73 | partition, 74 | timeout=4300, 75 | nodes=1, 76 | tasks_per_node=1, 77 | exclude_nodes=None 78 | ): 79 | executor = submitit.AutoExecutor( 80 | folder=os.path.join(submitit_folder, 'job_%j'), 81 | slurm_max_num_timeout=20) 82 | executor.update_parameters( 83 | slurm_partition=partition, 84 | slurm_mem_per_gpu='55G', 85 | timeout_min=timeout, 86 | nodes=nodes, 87 | tasks_per_node=tasks_per_node, 88 | cpus_per_task=12, 89 | gpus_per_node=tasks_per_node) 90 | 91 | if args.exclude is not None: 92 | executor.update_parameters(slurm_exclude=args.exclude) 93 | 94 | jobs, trainers = [], [] 95 | with executor.batch(): 96 | for ap in args_for_pretrain: 97 | fb_trainer = Trainer(ap) 98 | job = executor.submit(fb_trainer,) 99 | trainers.append(fb_trainer) 100 | jobs.append(job) 101 | 102 | for job in jobs: 103 | print(job.job_id) 104 | 105 | 106 | def launch(): 107 | 108 | # ---------------------------------------------------------------------- # 109 | # 1. 
Put config file names in a list 110 | # ---------------------------------------------------------------------- # 111 | config_fnames = [args.fname] 112 | 113 | # -- If batch-launch is True, then the args.fname yaml file is not a 114 | # -- config, but actually specifies a list of other config files 115 | # -- to run in a slurm job array 116 | if args.batch_launch: 117 | with open(args.fname, 'r') as y_file: 118 | config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) 119 | # ---------------------------------------------------------------------- # 120 | 121 | # ---------------------------------------------------------------------- # 122 | # 2. Parse each yaml config file as a dict and place in list 123 | # ---------------------------------------------------------------------- # 124 | nodes, tasks_per_node = None, None 125 | configs = [] 126 | for f in config_fnames: 127 | with open(f, 'r') as y_file: 128 | _params = yaml.load(y_file, Loader=yaml.FullLoader) 129 | nodes = int(_params.get('nodes')) 130 | tasks_per_node = int(_params.get('tasks_per_node')) 131 | configs += [_params] 132 | logger.info(f'Loaded {len(configs)} config files') 133 | logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') 134 | # ---------------------------------------------------------------------- # 135 | 136 | # ---------------------------------------------------------------------- # 137 | # 3. Launch evals with parsed config files 138 | # ---------------------------------------------------------------------- # 139 | launch_app_with_parsed_args( 140 | args_for_pretrain=configs, 141 | submitit_folder=args.folder, 142 | partition=args.partition, 143 | timeout=args.time, 144 | nodes=nodes, 145 | tasks_per_node=tasks_per_node, 146 | exclude_nodes=args.exclude) 147 | # ---------------------------------------------------------------------- # 148 | 149 | 150 | if __name__ == '__main__': 151 | args = parser.parse_args() 152 | launch() 153 | -------------------------------------------------------------------------------- /app/vjepa/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
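# --- Editorial note: a usage sketch, not part of the original file. ---
# make_transforms below returns a VideoTransform that consumes a clip of frames in
# (T, H, W, C) layout; the dummy uint8 input and the expected output shape are
# illustrative assumptions that follow the permute calls in VideoTransform.__call__.
#
#   import numpy as np
#   transform = make_transforms(crop_size=224, reprob=0.25, auto_augment=False)
#   frames = np.random.randint(0, 256, size=(16, 256, 320, 3), dtype=np.uint8)
#   clip = transform(frames)   # expected: float tensor of shape (3, 16, 224, 224), i.e. (C, T, H, W)
#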
6 | # 7 | 8 | import torch 9 | import torchvision.transforms as transforms 10 | 11 | import src.datasets.utils.video.transforms as video_transforms 12 | from src.datasets.utils.video.randerase import RandomErasing 13 | 14 | 15 | def make_transforms( 16 | random_horizontal_flip=True, 17 | random_resize_aspect_ratio=(3/4, 4/3), 18 | random_resize_scale=(0.3, 1.0), 19 | reprob=0.0, 20 | auto_augment=False, 21 | motion_shift=False, 22 | crop_size=224, 23 | normalize=((0.485, 0.456, 0.406), 24 | (0.229, 0.224, 0.225)) 25 | ): 26 | 27 | _frames_augmentation = VideoTransform( 28 | random_horizontal_flip=random_horizontal_flip, 29 | random_resize_aspect_ratio=random_resize_aspect_ratio, 30 | random_resize_scale=random_resize_scale, 31 | reprob=reprob, 32 | auto_augment=auto_augment, 33 | motion_shift=motion_shift, 34 | crop_size=crop_size, 35 | normalize=normalize, 36 | ) 37 | return _frames_augmentation 38 | 39 | 40 | class VideoTransform(object): 41 | 42 | def __init__( 43 | self, 44 | random_horizontal_flip=True, 45 | random_resize_aspect_ratio=(3/4, 4/3), 46 | random_resize_scale=(0.3, 1.0), 47 | reprob=0.0, 48 | auto_augment=False, 49 | motion_shift=False, 50 | crop_size=224, 51 | normalize=((0.485, 0.456, 0.406), 52 | (0.229, 0.224, 0.225)) 53 | ): 54 | 55 | self.random_horizontal_flip = random_horizontal_flip 56 | self.random_resize_aspect_ratio = random_resize_aspect_ratio 57 | self.random_resize_scale = random_resize_scale 58 | self.auto_augment = auto_augment 59 | self.motion_shift = motion_shift 60 | self.crop_size = crop_size 61 | self.mean = torch.tensor(normalize[0], dtype=torch.float32) 62 | self.std = torch.tensor(normalize[1], dtype=torch.float32) 63 | if not self.auto_augment: 64 | # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. 65 | self.mean *= 255. 66 | self.std *= 255. 
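# (editorial note) e.g. the ImageNet mean 0.485 becomes 123.675 after this scaling,
# so it can be subtracted directly from frames that are still in the raw [0, 255]
# range; when auto_augment is enabled, the ToPILImage/ToTensor calls in __call__
# already rescale frames to [0, 1], so the unscaled mean and std are kept instead.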
67 | 68 | self.autoaug_transform = video_transforms.create_random_augment( 69 | input_size=(crop_size, crop_size), 70 | auto_augment='rand-m7-n4-mstd0.5-inc1', 71 | interpolation='bicubic', 72 | ) 73 | 74 | self.spatial_transform = video_transforms.random_resized_crop_with_shift \ 75 | if motion_shift else video_transforms.random_resized_crop 76 | 77 | self.reprob = reprob 78 | self.erase_transform = RandomErasing( 79 | reprob, 80 | mode='pixel', 81 | max_count=1, 82 | num_splits=1, 83 | device='cpu', 84 | ) 85 | 86 | def __call__(self, buffer): 87 | 88 | if self.auto_augment: 89 | buffer = [transforms.ToPILImage()(frame) for frame in buffer] 90 | buffer = self.autoaug_transform(buffer) 91 | buffer = [transforms.ToTensor()(img) for img in buffer] 92 | buffer = torch.stack(buffer) # T C H W 93 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 94 | else: 95 | buffer = torch.tensor(buffer, dtype=torch.float32) 96 | 97 | buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W 98 | 99 | buffer = self.spatial_transform( 100 | images=buffer, 101 | target_height=self.crop_size, 102 | target_width=self.crop_size, 103 | scale=self.random_resize_scale, 104 | ratio=self.random_resize_aspect_ratio, 105 | ) 106 | if self.random_horizontal_flip: 107 | buffer, _ = video_transforms.horizontal_flip(0.5, buffer) 108 | 109 | buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) 110 | if self.reprob > 0: 111 | buffer = buffer.permute(1, 0, 2, 3) 112 | buffer = self.erase_transform(buffer) 113 | buffer = buffer.permute(1, 0, 2, 3) 114 | 115 | return buffer 116 | 117 | 118 | def tensor_normalize(tensor, mean, std): 119 | """ 120 | Normalize a given tensor by subtracting the mean and dividing the std. 121 | Args: 122 | tensor (tensor): tensor to normalize. 123 | mean (tensor or list): mean value to subtract. 124 | std (tensor or list): std to divide. 125 | """ 126 | if tensor.dtype == torch.uint8: 127 | tensor = tensor.float() 128 | tensor = tensor / 255.0 129 | if type(mean) == list: 130 | mean = torch.tensor(mean) 131 | if type(std) == list: 132 | std = torch.tensor(std) 133 | tensor = tensor - mean 134 | tensor = tensor / std 135 | return tensor 136 | 137 | 138 | def _tensor_normalize_inplace(tensor, mean, std): 139 | """ 140 | Normalize a given tensor by subtracting the mean and dividing the std. 141 | Args: 142 | tensor (tensor): tensor to normalize (with dimensions C, T, H, W). 143 | mean (tensor): mean value to subtract (in 0 to 255 floats). 144 | std (tensor): std to divide (in 0 to 255 floats). 145 | """ 146 | if tensor.dtype == torch.uint8: 147 | tensor = tensor.float() 148 | 149 | C, T, H, W = tensor.shape 150 | tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension 151 | tensor.sub_(mean).div_(std) 152 | tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front 153 | return tensor 154 | -------------------------------------------------------------------------------- /evals/main_distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
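# --- Editorial note: a usage sketch, not part of the original file. ---
# Besides the argparse CLI defined below, evaluations can also be launched
# programmatically; the dict contents and paths here are illustrative placeholders,
# and the only key this launcher itself reads is 'eval_name' (everything else is
# passed through to the evaluation code).
#
#   args_eval = {'eval_name': 'my_eval', ...}   # plus whatever the eval itself expects
#   launch_evals_with_parsed_args(
#       args_for_evals=[args_eval],
#       submitit_folder='/path/to/submitit/logs',
#       partition='my_partition',
#       nodes=1,
#       tasks_per_node=8)
#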
6 | # 7 | 8 | import argparse 9 | import logging 10 | import os 11 | import pprint 12 | import sys 13 | import time 14 | import yaml 15 | 16 | import submitit 17 | 18 | from evals.scaffold import main as eval_main 19 | 20 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 21 | logger = logging.getLogger() 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | '--folder', type=str, 26 | help='location to save submitit logs', 27 | default='/fsx-jepa/massran/submitit/') 28 | parser.add_argument( 29 | '--exclude', type=str, 30 | help='nodes to exclude from training', 31 | default=None) 32 | parser.add_argument( 33 | '--batch-launch', action='store_true', 34 | help='whether fname points to a file to batch-lauch several config files') 35 | parser.add_argument( 36 | '--fname', type=str, 37 | help='yaml file containing config file names to launch', 38 | default='configs.yaml') 39 | parser.add_argument( 40 | '--partition', type=str, 41 | help='cluster partition to submit jobs on') 42 | parser.add_argument( 43 | '--time', type=int, default=4300, 44 | help='time in minutes to run job') 45 | 46 | 47 | class Trainer: 48 | 49 | def __init__(self, args_eval=None, resume_preempt=None): 50 | self.eval_name = args_eval['eval_name'] 51 | self.args_eval = args_eval 52 | self.resume_preempt = resume_preempt 53 | 54 | def __call__(self): 55 | eval_name = self.eval_name 56 | args_eval = self.args_eval 57 | resume_preempt = self.resume_preempt 58 | 59 | logger.info('loaded eval params...') 60 | pp = pprint.PrettyPrinter(indent=4) 61 | pp.pprint(args_eval) 62 | 63 | eval_main( 64 | eval_name, 65 | args_eval=args_eval, 66 | resume_preempt=resume_preempt) 67 | 68 | def checkpoint(self): 69 | fb_trainer = Trainer(self.args_eval, True) 70 | return submitit.helpers.DelayedSubmission(fb_trainer,) 71 | 72 | 73 | def launch_evals_with_parsed_args( 74 | args_for_evals, 75 | submitit_folder, 76 | partition='learnlab,learnfair', 77 | timeout=4300, 78 | nodes=1, 79 | tasks_per_node=1, 80 | delay_seconds=10, 81 | exclude_nodes=None 82 | ): 83 | if not isinstance(args_for_evals, list): 84 | logger.info(f'Passed in eval-args of type {type(args_for_evals)}') 85 | args_for_evals = [args_for_evals] 86 | 87 | time.sleep(delay_seconds) 88 | logger.info('Launching evaluations in separate jobs...') 89 | executor = submitit.AutoExecutor( 90 | folder=os.path.join(submitit_folder, 'job_%j'), 91 | slurm_max_num_timeout=20) 92 | executor.update_parameters( 93 | slurm_partition=partition, 94 | slurm_mem_per_gpu='55G', 95 | timeout_min=timeout, 96 | nodes=nodes, 97 | tasks_per_node=tasks_per_node, 98 | cpus_per_task=12, 99 | gpus_per_node=tasks_per_node) 100 | 101 | if exclude_nodes is not None: 102 | executor.update_parameters(slurm_exclude=exclude_nodes) 103 | 104 | jobs, trainers = [], [] 105 | with executor.batch(): 106 | for ae in args_for_evals: 107 | fb_trainer = Trainer(ae) 108 | job = executor.submit(fb_trainer,) 109 | trainers.append(fb_trainer) 110 | jobs.append(job) 111 | 112 | for job in jobs: 113 | logger.info(f'Launched eval job with id {job.job_id}') 114 | 115 | 116 | def launch_evals(): 117 | 118 | # ---------------------------------------------------------------------- # 119 | # 1. 
Put config file names in a list 120 | # ---------------------------------------------------------------------- # 121 | config_fnames = [args.fname] 122 | 123 | # -- If batch-launch is True, then the args.fname yaml file is not a 124 | # -- config, but actually specifies a list of other config files 125 | # -- to run in a slurm job array 126 | if args.batch_launch: 127 | with open(args.fname, 'r') as y_file: 128 | config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) 129 | # ---------------------------------------------------------------------- # 130 | 131 | # ---------------------------------------------------------------------- # 132 | # 2. Parse each yaml config file as a dict and place in list 133 | # ---------------------------------------------------------------------- # 134 | nodes, tasks_per_node = None, None 135 | configs = [] 136 | for f in config_fnames: 137 | with open(f, 'r') as y_file: 138 | _params = yaml.load(y_file, Loader=yaml.FullLoader) 139 | nodes = int(_params.get('nodes')) 140 | tasks_per_node = int(_params.get('tasks_per_node')) 141 | configs += [_params] 142 | logger.info(f'Loaded {len(configs)} config files') 143 | logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') 144 | # ---------------------------------------------------------------------- # 145 | 146 | # ---------------------------------------------------------------------- # 147 | # 3. Launch evals with parsed config files 148 | # ---------------------------------------------------------------------- # 149 | launch_evals_with_parsed_args( 150 | args_for_evals=configs, 151 | submitit_folder=args.folder, 152 | partition=args.partition, 153 | timeout=args.time, 154 | nodes=nodes, 155 | tasks_per_node=tasks_per_node, 156 | exclude_nodes=args.exclude) 157 | # ---------------------------------------------------------------------- # 158 | 159 | 160 | if __name__ == '__main__': 161 | args = parser.parse_args() 162 | launch_evals() 163 | -------------------------------------------------------------------------------- /src/models/utils/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class MLP(nn.Module): 14 | def __init__( 15 | self, 16 | in_features, 17 | hidden_features=None, 18 | out_features=None, 19 | act_layer=nn.GELU, 20 | drop=0. 
21 | ): 22 | super().__init__() 23 | out_features = out_features or in_features 24 | hidden_features = hidden_features or in_features 25 | self.fc1 = nn.Linear(in_features, hidden_features) 26 | self.act = act_layer() 27 | self.fc2 = nn.Linear(hidden_features, out_features) 28 | self.drop = nn.Dropout(drop) 29 | 30 | def forward(self, x): 31 | x = self.fc1(x) 32 | x = self.act(x) 33 | x = self.drop(x) 34 | x = self.fc2(x) 35 | x = self.drop(x) 36 | return x 37 | 38 | 39 | class Attention(nn.Module): 40 | def __init__( 41 | self, 42 | dim, 43 | num_heads=8, 44 | qkv_bias=False, 45 | qk_scale=None, 46 | attn_drop=0., 47 | proj_drop=0., 48 | use_sdpa=True 49 | ): 50 | super().__init__() 51 | self.num_heads = num_heads 52 | head_dim = dim // num_heads 53 | self.scale = qk_scale or head_dim ** -0.5 54 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 55 | self.attn_drop = nn.Dropout(attn_drop) 56 | self.proj = nn.Linear(dim, dim) 57 | self.proj_drop_prob = proj_drop 58 | self.proj_drop = nn.Dropout(proj_drop) 59 | self.use_sdpa = use_sdpa 60 | 61 | def forward(self, x, mask=None): 62 | B, N, C = x.shape 63 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 64 | q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] 65 | 66 | if self.use_sdpa: 67 | with torch.backends.cuda.sdp_kernel(): 68 | x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) 69 | attn = None 70 | else: 71 | attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] 72 | attn = attn.softmax(dim=-1) 73 | attn = self.attn_drop(attn) 74 | x = (attn @ v) 75 | x = x.transpose(1, 2).reshape(B, N, C) 76 | x = self.proj(x) 77 | x = self.proj_drop(x) 78 | return x, attn 79 | 80 | 81 | class Block(nn.Module): 82 | def __init__( 83 | self, 84 | dim, 85 | num_heads, 86 | mlp_ratio=4., 87 | qkv_bias=False, 88 | qk_scale=None, 89 | drop=0., 90 | attn_drop=0., 91 | act_layer=nn.GELU, 92 | norm_layer=nn.LayerNorm, 93 | grid_size=None, 94 | grid_depth=None, 95 | ): 96 | super().__init__() 97 | self.norm1 = norm_layer(dim) 98 | self.attn = Attention( 99 | dim, 100 | num_heads=num_heads, 101 | qkv_bias=qkv_bias, 102 | qk_scale=qk_scale, 103 | attn_drop=attn_drop, 104 | proj_drop=drop) 105 | 106 | self.norm2 = norm_layer(dim) 107 | mlp_hidden_dim = int(dim * mlp_ratio) 108 | self.mlp = MLP( 109 | in_features=dim, 110 | hidden_features=mlp_hidden_dim, 111 | act_layer=act_layer, 112 | drop=drop) 113 | 114 | def forward(self, x, return_attention=False, mask=None): 115 | y, attn = self.attn(self.norm1(x), mask=mask) 116 | if return_attention: 117 | return attn 118 | x = x + y 119 | x = x + self.mlp(self.norm2(x)) 120 | return x 121 | 122 | 123 | class CrossAttention(nn.Module): 124 | def __init__( 125 | self, 126 | dim, 127 | num_heads=12, 128 | qkv_bias=False, 129 | use_sdpa=True 130 | ): 131 | super().__init__() 132 | self.num_heads = num_heads 133 | head_dim = dim // num_heads 134 | self.scale = head_dim ** -0.5 135 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 136 | self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) 137 | self.proj = nn.Linear(dim, dim) 138 | self.use_sdpa = use_sdpa 139 | 140 | def forward(self, q, x): 141 | B, n, C = q.shape 142 | q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 143 | 144 | B, N, C = x.shape 145 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 146 | k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) 147 | 148 | if 
self.use_sdpa: 149 | with torch.backends.cuda.sdp_kernel(): 150 | q = F.scaled_dot_product_attention(q, k, v) 151 | else: 152 | xattn = (q @ k.transpose(-2, -1)) * self.scale 153 | xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) 154 | q = (xattn @ v) 155 | 156 | q = q.transpose(1, 2).reshape(B, n, C) 157 | return q 158 | 159 | 160 | class CrossAttentionBlock(nn.Module): 161 | def __init__( 162 | self, 163 | dim, 164 | num_heads, 165 | mlp_ratio=4., 166 | qkv_bias=False, 167 | act_layer=nn.GELU, 168 | norm_layer=nn.LayerNorm 169 | ): 170 | super().__init__() 171 | self.norm1 = norm_layer(dim) 172 | self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) 173 | self.norm2 = norm_layer(dim) 174 | mlp_hidden_dim = int(dim * mlp_ratio) 175 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) 176 | 177 | def forward(self, q, x): 178 | y = self.xattn(q, self.norm1(x)) 179 | q = q + y 180 | q = q + self.mlp(self.norm2(q)) 181 | return q 182 | -------------------------------------------------------------------------------- /src/utils/monitoring.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import dataclasses 9 | import threading 10 | from typing import Dict, Tuple 11 | 12 | import psutil 13 | 14 | 15 | @dataclasses.dataclass 16 | class ResourceStatsSample: 17 | timestamp: float 18 | cpu_percent: float 19 | read_count: int 20 | write_count: int 21 | read_bytes: int 22 | write_bytes: int 23 | read_chars: int 24 | write_chars: int 25 | cpu_times_user: float 26 | cpu_times_system: float 27 | cpu_times_children_user: float 28 | cpu_times_children_system: float 29 | cpu_times_iowait: float 30 | cpu_affinity: str 31 | cpu_num: int 32 | num_threads: int 33 | num_voluntary_ctx_switches: int 34 | num_involuntary_ctx_switches: int 35 | 36 | def as_tuple(self) -> Dict: 37 | """Return values mirroring fields.""" 38 | return dataclasses.astuple(self) 39 | 40 | def fields(self) -> Tuple[dataclasses.Field, ...]: 41 | """Return fields in this dataclass.""" 42 | return dataclasses.fields(self.__class__) 43 | 44 | 45 | class ResourceMonitoringThread(threading.Thread): 46 | def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None): 47 | """Starts a thread to monitor pid every refresh_interval seconds. 
48 | 49 | Passes a ResourceStatsSample object to the callback.""" 50 | super(ResourceMonitoringThread, self).__init__() 51 | if refresh_interval is None: 52 | refresh_interval = 5 53 | self.is_running_event = threading.Event() 54 | self.p = psutil.Process(pid) 55 | self.refresh_interval = refresh_interval 56 | if stats_callback_fn is None: 57 | # Default callback 58 | def stats_callback_fn(resource_sample: ResourceStatsSample): 59 | print( 60 | f"PID {self.p.pid} Stats: {resource_sample.resource_stats}") 61 | elif not callable(stats_callback_fn): 62 | raise ValueError("Callback needs to be callable, got {}".format( 63 | type(stats_callback_fn))) 64 | self.stats_callback_fn = stats_callback_fn 65 | 66 | def stop(self) -> None: 67 | self.is_running_event.set() 68 | 69 | def run(self) -> None: 70 | while not self.is_running_event.is_set(): 71 | self.sample_counters() 72 | self.is_running_event.wait(self.refresh_interval) 73 | 74 | def log_sample(self, resource_sample: ResourceStatsSample) -> None: 75 | self.stats_callback_fn(resource_sample) 76 | 77 | def sample_counters(self) -> None: 78 | if not self.p.is_running(): 79 | self.stop() 80 | return 81 | 82 | with self.p.oneshot(): 83 | cpu_percent = self.p.cpu_percent() 84 | cpu_times = self.p.cpu_times() 85 | io_counters = self.p.io_counters() 86 | cpu_affinity = self.p.cpu_affinity() 87 | cpu_num = self.p.cpu_num() 88 | num_threads = self.p.num_threads() 89 | num_ctx_switches = self.p.num_ctx_switches() 90 | timestamp = time.time() 91 | 92 | read_count = io_counters.read_count 93 | write_count = io_counters.write_count 94 | read_bytes = io_counters.read_bytes 95 | write_bytes = io_counters.write_bytes 96 | read_chars = io_counters.read_chars 97 | write_chars = io_counters.write_chars 98 | 99 | def compress_cpu_affinity(cpu_affinity): 100 | """Change list representation to interval/range representation.""" 101 | if not cpu_affinity: 102 | return "" 103 | cpu_affinity_compressed = [] 104 | min_x = None 105 | max_x = None 106 | last_x = None 107 | 108 | # Find contiguous ranges 109 | for x in cpu_affinity: 110 | if last_x is None: 111 | # Start interval 112 | min_x = x 113 | max_x = x 114 | last_x = x 115 | continue 116 | elif x == (last_x + 1): 117 | # Move interval up 118 | max_x = x 119 | elif max_x is not None: 120 | # Interval ended, start again 121 | if min_x == max_x: 122 | cpu_affinity_compressed.append("{}".format(min_x)) 123 | else: 124 | cpu_affinity_compressed.append( 125 | "{}-{}".format(min_x, max_x)) 126 | min_x = x 127 | max_x = x 128 | last_x = x 129 | # Terminate last range 130 | if max_x is not None: 131 | if min_x == max_x: 132 | cpu_affinity_compressed.append("{}".format(min_x)) 133 | else: 134 | cpu_affinity_compressed.append( 135 | "{}-{}".format(min_x, max_x)) 136 | 137 | # Concat 138 | cpu_affinity_compressed = ",".join(cpu_affinity_compressed) 139 | 140 | return cpu_affinity_compressed 141 | 142 | cpu_affinity = compress_cpu_affinity(cpu_affinity) 143 | 144 | resource_sample = ResourceStatsSample( 145 | timestamp=timestamp, 146 | cpu_percent=cpu_percent, 147 | read_count=read_count, 148 | write_count=write_count, 149 | read_bytes=read_bytes, 150 | write_bytes=write_bytes, 151 | read_chars=read_chars, 152 | write_chars=write_chars, 153 | cpu_times_user=cpu_times.user, 154 | cpu_times_system=cpu_times.system, 155 | cpu_times_children_user=cpu_times.children_user, 156 | cpu_times_children_system=cpu_times.children_system, 157 | cpu_times_iowait=cpu_times.iowait, 158 | cpu_affinity=cpu_affinity, 159 | cpu_num=cpu_num, 160 | 
num_threads=num_threads, 161 | num_voluntary_ctx_switches=num_ctx_switches.voluntary, 162 | num_involuntary_ctx_switches=num_ctx_switches.involuntary, 163 | ) 164 | self.log_sample(resource_sample) 165 | 166 | 167 | if __name__ == "__main__": 168 | import multiprocessing 169 | import time 170 | pid = multiprocessing.current_process().pid 171 | monitor_thread = ResourceMonitoringThread(pid, 1) 172 | monitor_thread.start() 173 | time.sleep(5) 174 | print("Shutdown") 175 | monitor_thread.stop() 176 | -------------------------------------------------------------------------------- /app/vjepa/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import logging 9 | import sys 10 | import warnings 11 | import yaml 12 | 13 | 14 | import torch 15 | 16 | import src.models.vision_transformer as video_vit 17 | import src.models.predictor as vit_pred 18 | from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper 19 | from src.utils.schedulers import ( 20 | WarmupCosineSchedule, 21 | CosineWDSchedule) 22 | from src.utils.tensors import trunc_normal_ 23 | 24 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 25 | logger = logging.getLogger() 26 | 27 | 28 | def load_checkpoint( 29 | r_path, 30 | encoder, 31 | predictor, 32 | target_encoder, 33 | opt, 34 | scaler, 35 | ): 36 | try: 37 | checkpoint = torch.load(r_path, map_location=torch.device('cpu')) 38 | except Exception as e: 39 | logger.info(f'Encountered exception when loading checkpoint {e}') 40 | 41 | epoch = 0 42 | try: 43 | epoch = checkpoint['epoch'] 44 | 45 | # -- loading encoder 46 | pretrained_dict = checkpoint['encoder'] 47 | msg = encoder.load_state_dict(pretrained_dict) 48 | logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') 49 | 50 | # -- loading predictor 51 | pretrained_dict = checkpoint['predictor'] 52 | msg = predictor.load_state_dict(pretrained_dict) 53 | logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') 54 | 55 | # -- loading target_encoder 56 | if target_encoder is not None: 57 | print(list(checkpoint.keys())) 58 | pretrained_dict = checkpoint['target_encoder'] 59 | msg = target_encoder.load_state_dict(pretrained_dict) 60 | logger.info( 61 | f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' 62 | ) 63 | 64 | # -- loading optimizer 65 | opt.load_state_dict(checkpoint['opt']) 66 | if scaler is not None: 67 | scaler.load_state_dict(checkpoint['scaler']) 68 | logger.info(f'loaded optimizers from epoch {epoch}') 69 | logger.info(f'read-path: {r_path}') 70 | del checkpoint 71 | 72 | except Exception as e: 73 | logger.info(f'Encountered exception when loading checkpoint {e}') 74 | epoch = 0 75 | 76 | return ( 77 | encoder, 78 | predictor, 79 | target_encoder, 80 | opt, 81 | scaler, 82 | epoch, 83 | ) 84 | 85 | 86 | def init_video_model( 87 | device, 88 | patch_size=16, 89 | num_frames=16, 90 | tubelet_size=2, 91 | model_name='vit_base', 92 | crop_size=224, 93 | pred_depth=6, 94 | pred_embed_dim=384, 95 | uniform_power=False, 96 | use_mask_tokens=False, 97 | num_mask_tokens=2, 98 | zero_init_mask_tokens=True, 99 | use_sdpa=False, 100 | ): 101 | encoder = video_vit.__dict__[model_name]( 102 | img_size=crop_size, 103 | patch_size=patch_size, 104 | num_frames=num_frames, 
105 | tubelet_size=tubelet_size, 106 | uniform_power=uniform_power, 107 | use_sdpa=use_sdpa, 108 | ) 109 | encoder = MultiMaskWrapper(encoder) 110 | predictor = vit_pred.__dict__['vit_predictor']( 111 | img_size=crop_size, 112 | use_mask_tokens=use_mask_tokens, 113 | patch_size=patch_size, 114 | num_frames=num_frames, 115 | tubelet_size=tubelet_size, 116 | embed_dim=encoder.backbone.embed_dim, 117 | predictor_embed_dim=pred_embed_dim, 118 | depth=pred_depth, 119 | num_heads=encoder.backbone.num_heads, 120 | uniform_power=uniform_power, 121 | num_mask_tokens=num_mask_tokens, 122 | zero_init_mask_tokens=zero_init_mask_tokens, 123 | use_sdpa=use_sdpa, 124 | ) 125 | predictor = PredictorMultiMaskWrapper(predictor) 126 | 127 | def init_weights(m): 128 | if isinstance(m, torch.nn.Linear): 129 | trunc_normal_(m.weight, std=0.02) 130 | if m.bias is not None: 131 | torch.nn.init.constant_(m.bias, 0) 132 | elif isinstance(m, torch.nn.LayerNorm): 133 | torch.nn.init.constant_(m.bias, 0) 134 | torch.nn.init.constant_(m.weight, 1.0) 135 | 136 | for m in encoder.modules(): 137 | init_weights(m) 138 | 139 | for m in predictor.modules(): 140 | init_weights(m) 141 | 142 | encoder.to(device) 143 | predictor.to(device) 144 | logger.info(encoder) 145 | logger.info(predictor) 146 | 147 | def count_parameters(model): 148 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 149 | 150 | logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') 151 | logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') 152 | 153 | return encoder, predictor 154 | 155 | 156 | def init_opt( 157 | encoder, 158 | predictor, 159 | iterations_per_epoch, 160 | start_lr, 161 | ref_lr, 162 | warmup, 163 | num_epochs, 164 | wd=1e-6, 165 | final_wd=1e-6, 166 | final_lr=0.0, 167 | mixed_precision=False, 168 | ipe_scale=1.25, 169 | betas=(0.9, 0.999), 170 | eps=1e-8, 171 | zero_init_bias_wd=True, 172 | ): 173 | param_groups = [ 174 | { 175 | 'params': (p for n, p in encoder.named_parameters() 176 | if ('bias' not in n) and (len(p.shape) != 1)) 177 | }, { 178 | 'params': (p for n, p in predictor.named_parameters() 179 | if ('bias' not in n) and (len(p.shape) != 1)) 180 | }, { 181 | 'params': (p for n, p in encoder.named_parameters() 182 | if ('bias' in n) or (len(p.shape) == 1)), 183 | 'WD_exclude': zero_init_bias_wd, 184 | 'weight_decay': 0, 185 | }, { 186 | 'params': (p for n, p in predictor.named_parameters() 187 | if ('bias' in n) or (len(p.shape) == 1)), 188 | 'WD_exclude': zero_init_bias_wd, 189 | 'weight_decay': 0, 190 | }, 191 | ] 192 | 193 | logger.info('Using AdamW') 194 | optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) 195 | scheduler = WarmupCosineSchedule( 196 | optimizer, 197 | warmup_steps=int(warmup * iterations_per_epoch), 198 | start_lr=start_lr, 199 | ref_lr=ref_lr, 200 | final_lr=final_lr, 201 | T_max=int(ipe_scale * num_epochs * iterations_per_epoch), 202 | ) 203 | wd_scheduler = CosineWDSchedule( 204 | optimizer, 205 | ref_wd=wd, 206 | final_wd=final_wd, 207 | T_max=int(ipe_scale * num_epochs * iterations_per_epoch), 208 | ) 209 | scaler = torch.cuda.amp.GradScaler() if mixed_precision else None 210 | return optimizer, scaler, scheduler, wd_scheduler 211 | -------------------------------------------------------------------------------- /src/datasets/utils/video/randerase.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
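# --- Editorial note: a usage sketch, not part of the original file. ---
# This mirrors how VideoTransform in app/vjepa/transforms.py uses the class below:
# erasing is applied to an already-normalized clip laid out as (T, C, H, W), with the
# time dimension acting as the batch dimension; the shapes are illustrative.
#
#   erase = RandomErasing(probability=0.25, mode='pixel', max_count=1,
#                         num_splits=1, device='cpu')
#   clip = torch.randn(16, 3, 224, 224)   # (T, C, H, W), already normalized
#   clip = erase(clip)   # same shape; one region erased across all frames with prob 0.25
#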
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | This implementation is based on 10 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 11 | pulished under an Apache License 2.0. 12 | """ 13 | import math 14 | import random 15 | import torch 16 | 17 | 18 | def _get_pixels( 19 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 20 | ): 21 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 22 | # paths, flip the order so normal is run on CPU if this becomes a problem 23 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 24 | if per_pixel: 25 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 26 | elif rand_color: 27 | return torch.empty( 28 | (patch_size[0], 1, 1), dtype=dtype, device=device 29 | ).normal_() 30 | else: 31 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 32 | 33 | 34 | class RandomErasing: 35 | """Randomly selects a rectangle region in an image and erases its pixels. 36 | 'Random Erasing Data Augmentation' by Zhong et al. 37 | See https://arxiv.org/pdf/1708.04896.pdf 38 | This variant of RandomErasing is intended to be applied to either a batch 39 | or single image tensor after it has been normalized by dataset mean and std. 40 | Args: 41 | probability: Probability that the Random Erasing operation will be performed. 42 | min_area: Minimum percentage of erased area wrt input image area. 43 | max_area: Maximum percentage of erased area wrt input image area. 44 | min_aspect: Minimum aspect ratio of erased area. 45 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 46 | 'const' - erase block is constant color of 0 for all channels 47 | 'rand' - erase block is same per-channel random (normal) color 48 | 'pixel' - erase block is per-pixel random (normal) color 49 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 50 | per-image count is randomly chosen between 1 and this value. 
51 | """ 52 | 53 | def __init__( 54 | self, 55 | probability=0.5, 56 | min_area=0.02, 57 | max_area=1 / 3, 58 | min_aspect=0.3, 59 | max_aspect=None, 60 | mode="const", 61 | min_count=1, 62 | max_count=None, 63 | num_splits=0, 64 | device="cuda", 65 | cube=True, 66 | ): 67 | self.probability = probability 68 | self.min_area = min_area 69 | self.max_area = max_area 70 | max_aspect = max_aspect or 1 / min_aspect 71 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 72 | self.min_count = min_count 73 | self.max_count = max_count or min_count 74 | self.num_splits = num_splits 75 | mode = mode.lower() 76 | self.rand_color = False 77 | self.per_pixel = False 78 | self.cube = cube 79 | if mode == "rand": 80 | self.rand_color = True # per block random normal 81 | elif mode == "pixel": 82 | self.per_pixel = True # per pixel random normal 83 | else: 84 | assert not mode or mode == "const" 85 | self.device = device 86 | 87 | def _erase(self, img, chan, img_h, img_w, dtype): 88 | if random.random() > self.probability: 89 | return 90 | area = img_h * img_w 91 | count = ( 92 | self.min_count 93 | if self.min_count == self.max_count 94 | else random.randint(self.min_count, self.max_count) 95 | ) 96 | for _ in range(count): 97 | for _ in range(10): 98 | target_area = ( 99 | random.uniform(self.min_area, self.max_area) * area / count 100 | ) 101 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 102 | h = int(round(math.sqrt(target_area * aspect_ratio))) 103 | w = int(round(math.sqrt(target_area / aspect_ratio))) 104 | if w < img_w and h < img_h: 105 | top = random.randint(0, img_h - h) 106 | left = random.randint(0, img_w - w) 107 | img[:, top:top + h, left:left + w] = _get_pixels( 108 | self.per_pixel, 109 | self.rand_color, 110 | (chan, h, w), 111 | dtype=dtype, 112 | device=self.device, 113 | ) 114 | break 115 | 116 | def _erase_cube( 117 | self, 118 | img, 119 | batch_start, 120 | batch_size, 121 | chan, 122 | img_h, 123 | img_w, 124 | dtype, 125 | ): 126 | if random.random() > self.probability: 127 | return 128 | area = img_h * img_w 129 | count = ( 130 | self.min_count 131 | if self.min_count == self.max_count 132 | else random.randint(self.min_count, self.max_count) 133 | ) 134 | for _ in range(count): 135 | for _ in range(100): 136 | target_area = ( 137 | random.uniform(self.min_area, self.max_area) * area / count 138 | ) 139 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 140 | h = int(round(math.sqrt(target_area * aspect_ratio))) 141 | w = int(round(math.sqrt(target_area / aspect_ratio))) 142 | if w < img_w and h < img_h: 143 | top = random.randint(0, img_h - h) 144 | left = random.randint(0, img_w - w) 145 | for i in range(batch_start, batch_size): 146 | img_instance = img[i] 147 | img_instance[ 148 | :, top:top + h, left:left + w 149 | ] = _get_pixels( 150 | self.per_pixel, 151 | self.rand_color, 152 | (chan, h, w), 153 | dtype=dtype, 154 | device=self.device, 155 | ) 156 | break 157 | 158 | def __call__(self, input): 159 | if len(input.size()) == 3: 160 | self._erase(input, *input.size(), input.dtype) 161 | else: 162 | batch_size, chan, img_h, img_w = input.size() 163 | # skip first slice of batch if num_splits is set (for clean portion of samples) 164 | batch_start = ( 165 | batch_size // self.num_splits if self.num_splits > 1 else 0 166 | ) 167 | if self.cube: 168 | self._erase_cube( 169 | input, 170 | batch_start, 171 | batch_size, 172 | chan, 173 | img_h, 174 | img_w, 175 | input.dtype, 176 | ) 177 | else: 178 | for i in 
range(batch_start, batch_size): 179 | self._erase(input[i], chan, img_h, img_w, input.dtype) 180 | return input 181 | -------------------------------------------------------------------------------- /src/masks/multiblock3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import math 9 | 10 | from multiprocessing import Value 11 | 12 | from logging import getLogger 13 | 14 | import torch 15 | 16 | _GLOBAL_SEED = 0 17 | logger = getLogger() 18 | 19 | 20 | class MaskCollator(object): 21 | 22 | def __init__( 23 | self, 24 | cfgs_mask, 25 | crop_size=(224, 224), 26 | num_frames=16, 27 | patch_size=(16, 16), 28 | tubelet_size=2, 29 | ): 30 | super(MaskCollator, self).__init__() 31 | 32 | self.mask_generators = [] 33 | for m in cfgs_mask: 34 | mask_generator = _MaskGenerator( 35 | crop_size=crop_size, 36 | num_frames=num_frames, 37 | spatial_patch_size=patch_size, 38 | temporal_patch_size=tubelet_size, 39 | spatial_pred_mask_scale=m.get('spatial_scale'), 40 | temporal_pred_mask_scale=m.get('temporal_scale'), 41 | aspect_ratio=m.get('aspect_ratio'), 42 | npred=m.get('num_blocks'), 43 | max_context_frames_ratio=m.get('max_temporal_keep', 1.0), 44 | max_keep=m.get('max_keep', None), 45 | ) 46 | self.mask_generators.append(mask_generator) 47 | 48 | def step(self): 49 | for mask_generator in self.mask_generators: 50 | mask_generator.step() 51 | 52 | def __call__(self, batch): 53 | 54 | batch_size = len(batch) 55 | collated_batch = torch.utils.data.default_collate(batch) 56 | 57 | collated_masks_pred, collated_masks_enc = [], [] 58 | for i, mask_generator in enumerate(self.mask_generators): 59 | masks_enc, masks_pred = mask_generator(batch_size) 60 | collated_masks_enc.append(masks_enc) 61 | collated_masks_pred.append(masks_pred) 62 | 63 | return collated_batch, collated_masks_enc, collated_masks_pred 64 | 65 | 66 | class _MaskGenerator(object): 67 | 68 | def __init__( 69 | self, 70 | crop_size=(224, 224), 71 | num_frames=16, 72 | spatial_patch_size=(16, 16), 73 | temporal_patch_size=2, 74 | spatial_pred_mask_scale=(0.2, 0.8), 75 | temporal_pred_mask_scale=(1.0, 1.0), 76 | aspect_ratio=(0.3, 3.0), 77 | npred=1, 78 | max_context_frames_ratio=1.0, 79 | max_keep=None, 80 | ): 81 | super(_MaskGenerator, self).__init__() 82 | if not isinstance(crop_size, tuple): 83 | crop_size = (crop_size, ) * 2 84 | self.crop_size = crop_size 85 | self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size 86 | self.duration = num_frames // temporal_patch_size 87 | 88 | self.spatial_patch_size = spatial_patch_size 89 | self.temporal_patch_size = temporal_patch_size 90 | 91 | self.aspect_ratio = aspect_ratio 92 | self.spatial_pred_mask_scale = spatial_pred_mask_scale 93 | self.temporal_pred_mask_scale = temporal_pred_mask_scale 94 | self.npred = npred 95 | self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask 96 | self.max_keep = max_keep # maximum number of patches to keep in context 97 | self._itr_counter = Value('i', -1) # collator is shared across worker processes 98 | 99 | def step(self): 100 | i = self._itr_counter 101 | with i.get_lock(): 102 | i.value += 1 103 | v = i.value 104 | return v 105 | 106 | def _sample_block_size( 
107 | self, 108 | generator, 109 | temporal_scale, 110 | spatial_scale, 111 | aspect_ratio_scale 112 | ): 113 | # -- Sample temporal block mask scale 114 | _rand = torch.rand(1, generator=generator).item() 115 | min_t, max_t = temporal_scale 116 | temporal_mask_scale = min_t + _rand * (max_t - min_t) 117 | t = max(1, int(self.duration * temporal_mask_scale)) 118 | 119 | # -- Sample spatial block mask scale 120 | _rand = torch.rand(1, generator=generator).item() 121 | min_s, max_s = spatial_scale 122 | spatial_mask_scale = min_s + _rand * (max_s - min_s) 123 | spatial_num_keep = int(self.height * self.width * spatial_mask_scale) 124 | 125 | # -- Sample block aspect-ratio 126 | _rand = torch.rand(1, generator=generator).item() 127 | min_ar, max_ar = aspect_ratio_scale 128 | aspect_ratio = min_ar + _rand * (max_ar - min_ar) 129 | 130 | # -- Compute block height and width (given scale and aspect-ratio) 131 | h = int(round(math.sqrt(spatial_num_keep * aspect_ratio))) 132 | w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) 133 | h = min(h, self.height) 134 | w = min(w, self.width) 135 | 136 | return (t, h, w) 137 | 138 | def _sample_block_mask(self, b_size): 139 | t, h, w = b_size 140 | top = torch.randint(0, self.height - h + 1, (1,)) 141 | left = torch.randint(0, self.width - w + 1, (1,)) 142 | start = torch.randint(0, self.duration - t + 1, (1,)) 143 | 144 | mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) 145 | mask[start:start+t, top:top+h, left:left+w] = 0 146 | 147 | # Context mask will only span the first X frames 148 | # (X=self.max_context_frames) 149 | if self.max_context_duration < self.duration: 150 | mask[self.max_context_duration:, :, :] = 0 151 | 152 | # -- 153 | return mask 154 | 155 | def __call__(self, batch_size): 156 | """ 157 | Create encoder and predictor masks when collating imgs into a batch 158 | # 1. sample pred block size using seed 159 | # 2. sample several pred block locations for each image (w/o seed) 160 | # 3. 
return pred masks and complement (enc mask) 161 | """ 162 | seed = self.step() 163 | g = torch.Generator() 164 | g.manual_seed(seed) 165 | p_size = self._sample_block_size( 166 | generator=g, 167 | temporal_scale=self.temporal_pred_mask_scale, 168 | spatial_scale=self.spatial_pred_mask_scale, 169 | aspect_ratio_scale=self.aspect_ratio, 170 | ) 171 | 172 | collated_masks_pred, collated_masks_enc = [], [] 173 | min_keep_enc = min_keep_pred = self.duration * self.height * self.width 174 | for _ in range(batch_size): 175 | 176 | empty_context = True 177 | while empty_context: 178 | 179 | mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) 180 | for _ in range(self.npred): 181 | mask_e *= self._sample_block_mask(p_size) 182 | mask_e = mask_e.flatten() 183 | 184 | mask_p = torch.argwhere(mask_e == 0).squeeze() 185 | mask_e = torch.nonzero(mask_e).squeeze() 186 | 187 | empty_context = len(mask_e) == 0 188 | if not empty_context: 189 | min_keep_pred = min(min_keep_pred, len(mask_p)) 190 | min_keep_enc = min(min_keep_enc, len(mask_e)) 191 | collated_masks_pred.append(mask_p) 192 | collated_masks_enc.append(mask_e) 193 | 194 | if self.max_keep is not None: 195 | min_keep_enc = min(min_keep_enc, self.max_keep) 196 | 197 | collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] 198 | collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) 199 | # -- 200 | collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] 201 | collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) 202 | 203 | return collated_masks_enc, collated_masks_pred 204 | -------------------------------------------------------------------------------- /src/models/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
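(Illustrative sketch, added for clarity and not part of the original sources.) The MaskCollator from src/masks/multiblock3d.py above is meant to be used as a DataLoader collate_fn: each entry of cfgs_mask is a dict-like config exposing the keys read via m.get(...) in __init__, and every collated batch comes back with encoder and predictor mask indices. The scale/aspect numbers below are placeholders, video_dataset is any dataset of fixed-size clips (e.g. the VideoDataset defined later in this repo), and patch_size is passed as an int, which is what the integer division in _MaskGenerator expects.

import torch
from src.masks.multiblock3d import MaskCollator

cfgs_mask = [  # placeholder mask configs; in practice these come from the pretrain YAML
    {'spatial_scale': (0.15, 0.15), 'temporal_scale': (1.0, 1.0),
     'aspect_ratio': (0.75, 1.5), 'num_blocks': 8, 'max_temporal_keep': 1.0},
    {'spatial_scale': (0.7, 0.7), 'temporal_scale': (1.0, 1.0),
     'aspect_ratio': (0.75, 1.5), 'num_blocks': 2, 'max_temporal_keep': 1.0},
]
mask_collator = MaskCollator(
    cfgs_mask=cfgs_mask,
    crop_size=(224, 224),
    num_frames=16,
    patch_size=16,       # 224 // 16 = 14 tokens per spatial side
    tubelet_size=2)      # 16 // 2 = 8 temporal token slots

loader = torch.utils.data.DataLoader(video_dataset, batch_size=8, collate_fn=mask_collator)
for batch, masks_enc, masks_pred in loader:
    # masks_enc / masks_pred: one [batch_size, num_kept] index tensor per mask
    # config, indexing into the flattened 8 x 14 x 14 token grid
    ...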
6 | # 7 | 8 | import math 9 | from functools import partial 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from src.models.utils.modules import Block 15 | from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed 16 | from src.utils.tensors import ( 17 | trunc_normal_, 18 | repeat_interleave_batch 19 | ) 20 | from src.masks.utils import apply_masks 21 | 22 | 23 | class VisionTransformerPredictor(nn.Module): 24 | """ Vision Transformer """ 25 | def __init__( 26 | self, 27 | img_size=224, 28 | patch_size=16, 29 | num_frames=1, 30 | tubelet_size=2, 31 | embed_dim=768, 32 | predictor_embed_dim=384, 33 | depth=6, 34 | num_heads=12, 35 | mlp_ratio=4.0, 36 | qkv_bias=True, 37 | qk_scale=None, 38 | drop_rate=0.0, 39 | attn_drop_rate=0.0, 40 | norm_layer=nn.LayerNorm, 41 | init_std=0.02, 42 | uniform_power=False, 43 | use_mask_tokens=False, 44 | num_mask_tokens=2, 45 | zero_init_mask_tokens=True, 46 | **kwargs 47 | ): 48 | super().__init__() 49 | # Map input to predictor dimension 50 | self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) 51 | 52 | # Mask tokens 53 | self.mask_tokens = None 54 | self.num_mask_tokens = 0 55 | if use_mask_tokens: 56 | self.num_mask_tokens = num_mask_tokens 57 | self.mask_tokens = nn.ParameterList([ 58 | nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) 59 | for i in range(num_mask_tokens) 60 | ]) 61 | 62 | # Determine positional embedding 63 | self.input_size = img_size 64 | self.patch_size = patch_size 65 | # -- 66 | self.num_frames = num_frames 67 | self.tubelet_size = tubelet_size 68 | self.is_video = num_frames > 1 69 | 70 | grid_size = self.input_size // self.patch_size 71 | grid_depth = self.num_frames // self.tubelet_size 72 | 73 | if self.is_video: 74 | self.num_patches = num_patches = ( 75 | (num_frames // tubelet_size) 76 | * (img_size // patch_size) 77 | * (img_size // patch_size) 78 | ) 79 | else: 80 | self.num_patches = num_patches = ( 81 | (img_size // patch_size) 82 | * (img_size // patch_size) 83 | ) 84 | # Position embedding 85 | self.uniform_power = uniform_power 86 | self.predictor_pos_embed = None 87 | self.predictor_pos_embed = nn.Parameter( 88 | torch.zeros(1, num_patches, predictor_embed_dim), 89 | requires_grad=False) 90 | 91 | # Attention Blocks 92 | self.predictor_blocks = nn.ModuleList([ 93 | Block( 94 | dim=predictor_embed_dim, 95 | num_heads=num_heads, 96 | mlp_ratio=mlp_ratio, 97 | qkv_bias=qkv_bias, 98 | qk_scale=qk_scale, 99 | drop=drop_rate, 100 | act_layer=nn.GELU, 101 | attn_drop=attn_drop_rate, 102 | grid_size=grid_size, 103 | grid_depth=grid_depth, 104 | norm_layer=norm_layer) 105 | for i in range(depth)]) 106 | 107 | # Normalize & project back to input dimension 108 | self.predictor_norm = norm_layer(predictor_embed_dim) 109 | self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) 110 | 111 | # ------ initialize weights 112 | if self.predictor_pos_embed is not None: 113 | self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed 114 | self.init_std = init_std 115 | if not zero_init_mask_tokens: 116 | for mt in self.mask_tokens: 117 | trunc_normal_(mt, std=init_std) 118 | self.apply(self._init_weights) 119 | self._rescale_blocks() 120 | 121 | def _init_pos_embed(self, pos_embed): 122 | embed_dim = pos_embed.size(-1) 123 | grid_size = self.input_size // self.patch_size 124 | if self.is_video: 125 | grid_depth = self.num_frames // self.tubelet_size 126 | sincos = get_3d_sincos_pos_embed( 127 | embed_dim, 128 | grid_size, 129 | 
grid_depth, 130 | cls_token=False, 131 | uniform_power=self.uniform_power 132 | ) 133 | else: 134 | sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) 135 | pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) 136 | 137 | def _init_weights(self, m): 138 | if isinstance(m, nn.Linear): 139 | trunc_normal_(m.weight, std=self.init_std) 140 | if isinstance(m, nn.Linear) and m.bias is not None: 141 | nn.init.constant_(m.bias, 0) 142 | elif isinstance(m, nn.LayerNorm): 143 | nn.init.constant_(m.bias, 0) 144 | nn.init.constant_(m.weight, 1.0) 145 | 146 | def _rescale_blocks(self): 147 | def rescale(param, layer_id): 148 | param.div_(math.sqrt(2.0 * layer_id)) 149 | 150 | for layer_id, layer in enumerate(self.predictor_blocks): 151 | rescale(layer.attn.proj.weight.data, layer_id + 1) 152 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 153 | 154 | def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): 155 | 156 | # Prepare diffusion noise schedule 157 | b1, b2 = noise_beta 158 | beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) 159 | alpha_scheduler = [] 160 | _alpha = 1.0 161 | for _beta in beta_scheduler: 162 | _alpha *= 1.-_beta 163 | alpha_scheduler += [_alpha] 164 | 165 | # Sample diffusion time step 166 | T = torch.randint(0, steps, (len(x),)) 167 | alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) 168 | 169 | # Normalize features and apply noise 170 | x = torch.nn.functional.layer_norm(x, (x.size(-1),)) 171 | x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) 172 | return x 173 | 174 | def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): 175 | """ 176 | :param ctxt: context tokens 177 | :param tgt: target tokens 178 | :param masks_ctxt: indices of context tokens in input 179 | :params masks_tgt: indices of target tokens in input 180 | """ 181 | 182 | assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' 183 | 184 | if not isinstance(masks_ctxt, list): 185 | masks_ctxt = [masks_ctxt] 186 | 187 | if not isinstance(masks_tgt, list): 188 | masks_tgt = [masks_tgt] 189 | 190 | # Batch Size 191 | B = len(ctxt) // len(masks_ctxt) 192 | 193 | # Map context tokens to pedictor dimensions 194 | x = self.predictor_embed(ctxt) 195 | _, N_ctxt, D = x.shape 196 | 197 | # Add positional embedding to ctxt tokens 198 | if self.predictor_pos_embed is not None: 199 | ctxt_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) 200 | x += apply_masks(ctxt_pos_embed, masks_ctxt) 201 | 202 | # Map target tokens to predictor dimensions & add noise (fwd diffusion) 203 | if self.mask_tokens is None: 204 | pred_tokens = self.predictor_embed(tgt) 205 | pred_tokens = self.diffusion(pred_tokens) 206 | else: 207 | mask_index = mask_index % self.num_mask_tokens 208 | pred_tokens = self.mask_tokens[mask_index] 209 | pred_tokens = pred_tokens.repeat(B, self.num_patches, 1) 210 | pred_tokens = apply_masks(pred_tokens, masks_tgt) 211 | 212 | # Add positional embedding to target tokens 213 | if self.predictor_pos_embed is not None: 214 | pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) 215 | pos_embs = apply_masks(pos_embs, masks_tgt) 216 | pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_ctxt)) 217 | pred_tokens += pos_embs 218 | 219 | # Concatenate context & target tokens 220 | x = x.repeat(len(masks_tgt), 1, 1) 221 | x = torch.cat([x, pred_tokens], dim=1) 222 | 223 | # FIXME: this implementation currently assumes masks_ctxt and masks_tgt 
224 | # are alligned 1:1 (ok with MultiMask wrapper on predictor but 225 | # otherwise will break) 226 | masks_ctxt = torch.cat(masks_ctxt, dim=0) 227 | masks_tgt = torch.cat(masks_tgt, dim=0) 228 | masks = torch.cat([masks_ctxt, masks_tgt], dim=1) 229 | 230 | # Fwd prop 231 | for blk in self.predictor_blocks: 232 | x = blk(x, mask=masks) 233 | x = self.predictor_norm(x) 234 | 235 | # Return output corresponding to target tokens 236 | x = x[:, N_ctxt:] 237 | x = self.predictor_proj(x) 238 | 239 | return x 240 | 241 | 242 | def vit_predictor(**kwargs): 243 | model = VisionTransformerPredictor( 244 | mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 245 | **kwargs) 246 | return model 247 | -------------------------------------------------------------------------------- /src/datasets/video_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import pathlib 10 | import warnings 11 | 12 | from logging import getLogger 13 | 14 | import numpy as np 15 | import pandas as pd 16 | 17 | from decord import VideoReader, cpu 18 | 19 | import torch 20 | 21 | from src.datasets.utils.weighted_sampler import DistributedWeightedSampler 22 | 23 | _GLOBAL_SEED = 0 24 | logger = getLogger() 25 | 26 | 27 | def make_videodataset( 28 | data_paths, 29 | batch_size, 30 | frames_per_clip=8, 31 | frame_step=4, 32 | num_clips=1, 33 | random_clip_sampling=True, 34 | allow_clip_overlap=False, 35 | filter_short_videos=False, 36 | filter_long_videos=int(10**9), 37 | transform=None, 38 | shared_transform=None, 39 | rank=0, 40 | world_size=1, 41 | datasets_weights=None, 42 | collator=None, 43 | drop_last=True, 44 | num_workers=10, 45 | pin_mem=True, 46 | duration=None, 47 | log_dir=None, 48 | ): 49 | dataset = VideoDataset( 50 | data_paths=data_paths, 51 | datasets_weights=datasets_weights, 52 | frames_per_clip=frames_per_clip, 53 | frame_step=frame_step, 54 | num_clips=num_clips, 55 | random_clip_sampling=random_clip_sampling, 56 | allow_clip_overlap=allow_clip_overlap, 57 | filter_short_videos=filter_short_videos, 58 | filter_long_videos=filter_long_videos, 59 | duration=duration, 60 | shared_transform=shared_transform, 61 | transform=transform) 62 | 63 | logger.info('VideoDataset dataset created') 64 | if datasets_weights is not None: 65 | dist_sampler = DistributedWeightedSampler( 66 | dataset.sample_weights, 67 | num_replicas=world_size, 68 | rank=rank, 69 | shuffle=True) 70 | else: 71 | dist_sampler = torch.utils.data.distributed.DistributedSampler( 72 | dataset, 73 | num_replicas=world_size, 74 | rank=rank, 75 | shuffle=True) 76 | 77 | data_loader = torch.utils.data.DataLoader( 78 | dataset, 79 | collate_fn=collator, 80 | sampler=dist_sampler, 81 | batch_size=batch_size, 82 | drop_last=drop_last, 83 | pin_memory=pin_mem, 84 | num_workers=num_workers, 85 | persistent_workers=num_workers > 0) 86 | logger.info('VideoDataset unsupervised data loader created') 87 | 88 | return dataset, data_loader, dist_sampler 89 | 90 | 91 | class VideoDataset(torch.utils.data.Dataset): 92 | """ Video classification dataset. 
""" 93 | 94 | def __init__( 95 | self, 96 | data_paths, 97 | datasets_weights=None, 98 | frames_per_clip=16, 99 | frame_step=4, 100 | num_clips=1, 101 | transform=None, 102 | shared_transform=None, 103 | random_clip_sampling=True, 104 | allow_clip_overlap=False, 105 | filter_short_videos=False, 106 | filter_long_videos=int(10**9), 107 | duration=None, # duration in seconds 108 | ): 109 | self.data_paths = data_paths 110 | self.datasets_weights = datasets_weights 111 | self.frames_per_clip = frames_per_clip 112 | self.frame_step = frame_step 113 | self.num_clips = num_clips 114 | self.transform = transform 115 | self.shared_transform = shared_transform 116 | self.random_clip_sampling = random_clip_sampling 117 | self.allow_clip_overlap = allow_clip_overlap 118 | self.filter_short_videos = filter_short_videos 119 | self.filter_long_videos = filter_long_videos 120 | self.duration = duration 121 | 122 | if VideoReader is None: 123 | raise ImportError('Unable to import "decord" which is required to read videos.') 124 | 125 | # Load video paths and labels 126 | samples, labels = [], [] 127 | self.num_samples_per_dataset = [] 128 | for data_path in self.data_paths: 129 | 130 | if data_path[-4:] == '.csv': 131 | data = pd.read_csv(data_path, header=None, delimiter=" ") 132 | samples += list(data.values[:, 0]) 133 | labels += list(data.values[:, 1]) 134 | num_samples = len(data) 135 | self.num_samples_per_dataset.append(num_samples) 136 | 137 | elif data_path[-4:] == '.npy': 138 | data = np.load(data_path, allow_pickle=True) 139 | data = list(map(lambda x: repr(x)[1:-1], data)) 140 | samples += data 141 | labels += [0] * len(data) 142 | num_samples = len(data) 143 | self.num_samples_per_dataset.append(len(data)) 144 | 145 | # [Optional] Weights for each sample to be used by downstream 146 | # weighted video sampler 147 | self.sample_weights = None 148 | if self.datasets_weights is not None: 149 | self.sample_weights = [] 150 | for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset): 151 | self.sample_weights += [dw / ns] * ns 152 | 153 | self.samples = samples 154 | self.labels = labels 155 | 156 | def __getitem__(self, index): 157 | sample = self.samples[index] 158 | 159 | # Keep trying to load videos until you find a valid sample 160 | loaded_video = False 161 | while not loaded_video: 162 | buffer, clip_indices = self.loadvideo_decord(sample) # [T H W 3] 163 | loaded_video = len(buffer) > 0 164 | if not loaded_video: 165 | index = np.random.randint(self.__len__()) 166 | sample = self.samples[index] 167 | 168 | # Label/annotations for video 169 | label = self.labels[index] 170 | 171 | def split_into_clips(video): 172 | """ Split video into a list of clips """ 173 | fpc = self.frames_per_clip 174 | nc = self.num_clips 175 | return [video[i*fpc:(i+1)*fpc] for i in range(nc)] 176 | 177 | # Parse video into frames & apply data augmentations 178 | if self.shared_transform is not None: 179 | buffer = self.shared_transform(buffer) 180 | buffer = split_into_clips(buffer) 181 | if self.transform is not None: 182 | buffer = [self.transform(clip) for clip in buffer] 183 | 184 | return buffer, label, clip_indices 185 | 186 | def loadvideo_decord(self, sample): 187 | """ Load video content using Decord """ 188 | 189 | fname = sample 190 | if not os.path.exists(fname): 191 | warnings.warn(f'video path not found {fname=}') 192 | return [], None 193 | 194 | _fsize = os.path.getsize(fname) 195 | if _fsize < 1 * 1024: # avoid hanging issue 196 | warnings.warn(f'video too short {fname=}') 197 | 
return [], None 198 | if _fsize > self.filter_long_videos: 199 | warnings.warn(f'skipping long video of size {_fsize=} (bytes)') 200 | return [], None 201 | 202 | try: 203 | vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) 204 | except Exception: 205 | return [], None 206 | 207 | fpc = self.frames_per_clip 208 | fstp = self.frame_step 209 | if self.duration is not None: 210 | try: 211 | fps = vr.get_avg_fps() 212 | fstp = int(self.duration * fps / fpc) 213 | except Exception as e: 214 | warnings.warn(e) 215 | clip_len = int(fpc * fstp) 216 | 217 | if self.filter_short_videos and len(vr) < clip_len: 218 | warnings.warn(f'skipping video of length {len(vr)}') 219 | return [], None 220 | 221 | vr.seek(0) # Go to start of video before sampling frames 222 | 223 | # Partition video into equal sized segments and sample each clip 224 | # from a different segment 225 | partition_len = len(vr) // self.num_clips 226 | 227 | all_indices, clip_indices = [], [] 228 | for i in range(self.num_clips): 229 | 230 | if partition_len > clip_len: 231 | # If partition_len > clip len, then sample a random window of 232 | # clip_len frames within the segment 233 | end_indx = clip_len 234 | if self.random_clip_sampling: 235 | end_indx = np.random.randint(clip_len, partition_len) 236 | start_indx = end_indx - clip_len 237 | indices = np.linspace(start_indx, end_indx, num=fpc) 238 | indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) 239 | # -- 240 | indices = indices + i * partition_len 241 | else: 242 | # If partition overlap not allowed and partition_len < clip_len 243 | # then repeatedly append the last frame in the segment until 244 | # we reach the desired clip length 245 | if not self.allow_clip_overlap: 246 | indices = np.linspace(0, partition_len, num=partition_len // fstp) 247 | indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) 248 | indices = np.clip(indices, 0, partition_len-1).astype(np.int64) 249 | # -- 250 | indices = indices + i * partition_len 251 | 252 | # If partition overlap is allowed and partition_len < clip_len 253 | # then start_indx of segment i+1 will lie within segment i 254 | else: 255 | sample_len = min(clip_len, len(vr)) - 1 256 | indices = np.linspace(0, sample_len, num=sample_len // fstp) 257 | indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) 258 | indices = np.clip(indices, 0, sample_len-1).astype(np.int64) 259 | # -- 260 | clip_step = 0 261 | if len(vr) > clip_len: 262 | clip_step = (len(vr) - clip_len) // (self.num_clips - 1) 263 | indices = indices + i * clip_step 264 | 265 | clip_indices.append(indices) 266 | all_indices.extend(list(indices)) 267 | 268 | buffer = vr.get_batch(all_indices).asnumpy() 269 | return buffer, clip_indices 270 | 271 | def __len__(self): 272 | return len(self.samples) 273 | -------------------------------------------------------------------------------- /src/models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
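(Illustrative sketch, added for clarity and not part of the original sources.) make_videodataset above wraps VideoDataset in a DistributedSampler-backed DataLoader. The sketch assumes the space-delimited CSV annotation format parsed in VideoDataset.__init__, one "<video_path> <label>" row per line; the path is a placeholder, and passing rank=0 / world_size=1 keeps the sampler usable from a single process. collator can be None for plain batching or a mask collator (as sketched earlier) for masked pre-training.

from src.datasets.video_dataset import make_videodataset

dataset, loader, sampler = make_videodataset(
    data_paths=['/path/to/train_annotations.csv'],  # placeholder CSV of "<video_path> <label>" rows
    batch_size=8,
    frames_per_clip=16,
    frame_step=4,
    num_clips=1,
    rank=0,
    world_size=1,
    collator=None,
    num_workers=4)

for batch in loader:
    # each dataset item is (clips, label, clip_indices), where `clips` is a list
    # of `num_clips` clips after `shared_transform`/`transform` have been applied
    ...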
6 | # 7 | 8 | import math 9 | from functools import partial 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D 15 | from src.models.utils.modules import Block 16 | from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed 17 | from src.utils.tensors import trunc_normal_ 18 | from src.masks.utils import apply_masks 19 | 20 | 21 | class VisionTransformer(nn.Module): 22 | """ Vision Transformer """ 23 | def __init__( 24 | self, 25 | img_size=224, 26 | patch_size=16, 27 | num_frames=1, 28 | tubelet_size=2, 29 | in_chans=3, 30 | embed_dim=768, 31 | depth=12, 32 | num_heads=12, 33 | mlp_ratio=4.0, 34 | qkv_bias=True, 35 | qk_scale=None, 36 | drop_rate=0.0, 37 | attn_drop_rate=0.0, 38 | norm_layer=nn.LayerNorm, 39 | init_std=0.02, 40 | out_layers=None, 41 | uniform_power=False, 42 | **kwargs 43 | ): 44 | super().__init__() 45 | self.num_features = self.embed_dim = embed_dim 46 | self.num_heads = num_heads 47 | self.out_layers = out_layers 48 | 49 | self.input_size = img_size 50 | self.patch_size = patch_size 51 | 52 | self.num_frames = num_frames 53 | self.tubelet_size = tubelet_size 54 | self.is_video = num_frames > 1 55 | 56 | grid_size = self.input_size // self.patch_size 57 | grid_depth = self.num_frames // self.tubelet_size 58 | 59 | # Tokenize pixels with convolution 60 | if self.is_video: 61 | self.patch_embed = PatchEmbed3D( 62 | patch_size=patch_size, 63 | tubelet_size=tubelet_size, 64 | in_chans=in_chans, 65 | embed_dim=embed_dim) 66 | self.num_patches = ( 67 | (num_frames // tubelet_size) 68 | * (img_size // patch_size) 69 | * (img_size // patch_size) 70 | ) 71 | else: 72 | self.patch_embed = PatchEmbed( 73 | patch_size=patch_size, 74 | in_chans=in_chans, 75 | embed_dim=embed_dim) 76 | self.num_patches = ( 77 | (img_size // patch_size) 78 | * (img_size // patch_size) 79 | ) 80 | 81 | # Position embedding 82 | self.uniform_power = uniform_power 83 | self.pos_embed = None 84 | self.pos_embed = nn.Parameter( 85 | torch.zeros(1, self.num_patches, embed_dim), 86 | requires_grad=False) 87 | 88 | # Attention Blocks 89 | self.blocks = nn.ModuleList([ 90 | Block( 91 | dim=embed_dim, 92 | num_heads=num_heads, 93 | mlp_ratio=mlp_ratio, 94 | qkv_bias=qkv_bias, 95 | qk_scale=qk_scale, 96 | drop=drop_rate, 97 | act_layer=nn.GELU, 98 | grid_size=grid_size, 99 | grid_depth=grid_depth, 100 | attn_drop=attn_drop_rate, 101 | norm_layer=norm_layer) 102 | for i in range(depth)]) 103 | self.norm = norm_layer(embed_dim) 104 | 105 | # ------ initialize weights 106 | if self.pos_embed is not None: 107 | self._init_pos_embed(self.pos_embed.data) # sincos pos-embed 108 | self.init_std = init_std 109 | self.apply(self._init_weights) 110 | self._rescale_blocks() 111 | 112 | def _init_pos_embed(self, pos_embed): 113 | embed_dim = pos_embed.size(-1) 114 | grid_size = self.input_size // self.patch_size 115 | if self.is_video: 116 | grid_depth = self.num_frames // self.tubelet_size 117 | sincos = get_3d_sincos_pos_embed( 118 | embed_dim, 119 | grid_size, 120 | grid_depth, 121 | cls_token=False, 122 | uniform_power=self.uniform_power 123 | ) 124 | else: 125 | sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) 126 | pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) 127 | 128 | def _init_weights(self, m): 129 | if isinstance(m, nn.Linear): 130 | trunc_normal_(m.weight, std=self.init_std) 131 | if isinstance(m, nn.Linear) and m.bias is not None: 132 | nn.init.constant_(m.bias, 0) 133 | 
elif isinstance(m, nn.LayerNorm): 134 | nn.init.constant_(m.bias, 0) 135 | nn.init.constant_(m.weight, 1.0) 136 | elif isinstance(m, nn.Conv2d): 137 | trunc_normal_(m.weight, std=self.init_std) 138 | if m.bias is not None: 139 | nn.init.constant_(m.bias, 0) 140 | elif isinstance(m, nn.Conv3d): 141 | trunc_normal_(m.weight, std=self.init_std) 142 | if m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | 145 | def _rescale_blocks(self): 146 | def rescale(param, layer_id): 147 | param.div_(math.sqrt(2.0 * layer_id)) 148 | 149 | for layer_id, layer in enumerate(self.blocks): 150 | rescale(layer.attn.proj.weight.data, layer_id + 1) 151 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 152 | 153 | def get_num_layers(self): 154 | return len(self.blocks) 155 | 156 | def no_weight_decay(self): 157 | return {} 158 | 159 | def forward(self, x, masks=None): 160 | """ 161 | :param x: input image/video 162 | :param masks: indices of patch tokens to mask (remove) 163 | """ 164 | 165 | if masks is not None and not isinstance(masks, list): 166 | masks = [masks] 167 | 168 | # Tokenize input 169 | pos_embed = self.pos_embed 170 | if pos_embed is not None: 171 | pos_embed = self.interpolate_pos_encoding(x, pos_embed) 172 | x = self.patch_embed(x) 173 | if pos_embed is not None: 174 | x += pos_embed 175 | B, N, D = x.shape 176 | 177 | # Mask away unwanted tokens (if masks provided) 178 | if masks is not None: 179 | x = apply_masks(x, masks) 180 | masks = torch.cat(masks, dim=0) 181 | 182 | # Fwd prop 183 | outs = [] 184 | for i, blk in enumerate(self.blocks): 185 | x = blk(x, mask=masks) 186 | if self.out_layers is not None and i in self.out_layers: 187 | outs.append(self.norm(x)) 188 | 189 | if self.out_layers is not None: 190 | return outs 191 | 192 | if self.norm is not None: 193 | x = self.norm(x) 194 | 195 | return x 196 | 197 | def interpolate_pos_encoding(self, x, pos_embed): 198 | 199 | _, N, dim = pos_embed.shape 200 | 201 | if self.is_video: 202 | 203 | # If pos_embed already corret size, just return 204 | _, _, T, H, W = x.shape 205 | if H == self.input_size and W == self.input_size and T == self.num_frames: 206 | return pos_embed 207 | 208 | # Convert depth, height, width of input to be measured in patches 209 | # instead of pixels/frames 210 | T = T // self.tubelet_size 211 | H = H // self.patch_size 212 | W = W // self.patch_size 213 | 214 | # Compute the initialized shape of the positional embedding measured 215 | # in patches 216 | N_t = self.num_frames // self.tubelet_size 217 | N_h = N_w = self.input_size // self.patch_size 218 | assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' 219 | 220 | # Compute scale factor for spatio-temporal interpolation 221 | scale_factor = (T/N_t, H/N_h, W/N_w) 222 | 223 | pos_embed = nn.functional.interpolate( 224 | pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), 225 | scale_factor=scale_factor, 226 | mode='trilinear') 227 | pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim) 228 | return pos_embed 229 | 230 | else: 231 | 232 | # If pos_embed already corret size, just return 233 | _, _, H, W = x.shape 234 | if H == self.input_size and W == self.input_size: 235 | return pos_embed 236 | 237 | # Compute scale factor for spatial interpolation 238 | npatch = (H // self.patch_size) * (W // self.patch_size) 239 | scale_factor = math.sqrt(npatch / N) 240 | 241 | pos_embed = nn.functional.interpolate( 242 | pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 243 | 
scale_factor=scale_factor, 244 | mode='bicubic') 245 | pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 246 | return pos_embed 247 | 248 | 249 | def vit_tiny(patch_size=16, **kwargs): 250 | model = VisionTransformer( 251 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 252 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 253 | return model 254 | 255 | 256 | def vit_small(patch_size=16, **kwargs): 257 | model = VisionTransformer( 258 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 259 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 260 | return model 261 | 262 | 263 | def vit_base(patch_size=16, **kwargs): 264 | model = VisionTransformer( 265 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 266 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 267 | return model 268 | 269 | 270 | def vit_large(patch_size=16, **kwargs): 271 | model = VisionTransformer( 272 | patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, 273 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 274 | return model 275 | 276 | 277 | def vit_huge(patch_size=16, **kwargs): 278 | model = VisionTransformer( 279 | patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, 280 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 281 | return model 282 | 283 | 284 | def vit_giant(patch_size=16, **kwargs): 285 | model = VisionTransformer( 286 | patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, 287 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 288 | return model 289 | 290 | 291 | def vit_gigantic(patch_size=14, **kwargs): 292 | model = VisionTransformer( 293 | patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=64/13, 294 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs 295 | ) 296 | return model 297 | 298 | 299 | VIT_EMBED_DIMS = { 300 | 'vit_tiny': 192, 301 | 'vit_small': 384, 302 | 'vit_base': 768, 303 | 'vit_large': 1024, 304 | 'vit_huge': 1280, 305 | 'vit_giant': 1408, 306 | 'vit_gigantic': 1664, 307 | } 308 | -------------------------------------------------------------------------------- /evals/video_classification_frozen/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
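(Illustrative sketch, added for clarity and not part of the original sources.) The factory functions above are how encoders are instantiated elsewhere in this repo, e.g. vit.__dict__[model_name](...) in the frozen-evaluation scripts, with VIT_EMBED_DIMS giving the matching embedding widths. The batch and input sizes below are placeholders.

import torch
import src.models.vision_transformer as vit

model_name = 'vit_large'
encoder = vit.__dict__[model_name](
    img_size=224,
    patch_size=16,
    num_frames=16,    # num_frames > 1 selects the 3-D (video) patch embedding
    tubelet_size=2)
assert encoder.embed_dim == vit.VIT_EMBED_DIMS[model_name]

clips = torch.randn(2, 3, 16, 224, 224)  # [B, C, T, H, W]
tokens = encoder(clips)                  # [2, (16//2) * (224//16)**2, 1024] = [2, 1568, 1024]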
6 | # 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torchvision.transforms as transforms 13 | 14 | import src.datasets.utils.video.transforms as video_transforms 15 | import src.datasets.utils.video.volume_transforms as volume_transforms 16 | 17 | from src.datasets.utils.video.randerase import RandomErasing 18 | 19 | from src.models.utils.pos_embs import get_1d_sincos_pos_embed 20 | from src.masks.utils import apply_masks 21 | 22 | 23 | class FrameAggregation(nn.Module): 24 | """ 25 | Process each frame independently and concatenate all tokens 26 | """ 27 | 28 | def __init__( 29 | self, 30 | model, 31 | max_frames=10000, 32 | use_pos_embed=False, 33 | attend_across_segments=False 34 | ): 35 | super().__init__() 36 | self.model = model 37 | self.embed_dim = embed_dim = model.embed_dim 38 | self.num_heads = model.num_heads 39 | self.attend_across_segments = attend_across_segments 40 | # 1D-temporal pos-embedding 41 | self.pos_embed = None 42 | if use_pos_embed: 43 | self.pos_embed = nn.Parameter( 44 | torch.zeros(1, max_frames, embed_dim), 45 | requires_grad=False) 46 | sincos = get_1d_sincos_pos_embed(embed_dim, max_frames) 47 | self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) 48 | 49 | def forward(self, x, clip_indices=None): 50 | 51 | # TODO: implement attend_across_segments=False 52 | # num_clips = len(x) 53 | num_views_per_clip = len(x[0]) 54 | 55 | # Concatenate views along batch dimension 56 | x = [torch.cat(xi, dim=0) for xi in x] 57 | # Concatenate clips along temporal dimension 58 | x = torch.cat(x, dim=2) 59 | B, C, T, H, W = x.size() 60 | 61 | # Put each frame along the batch dimension 62 | x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W) 63 | 64 | outputs = self.model(x) 65 | _, N, D = outputs.size() 66 | outputs = outputs.reshape(B, T, N, D).flatten(1, 2) 67 | 68 | # Separate views into list 69 | B = B // num_views_per_clip 70 | all_outputs = [] 71 | for i in range(num_views_per_clip): 72 | o = outputs[i*B:(i+1)*B] 73 | # Compute positional embedding 74 | if (self.pos_embed is not None) and (clip_indices is not None): 75 | pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] 76 | pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) 77 | pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension 78 | pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] 79 | pos_embed = pos_embed.flatten(1, 2) 80 | o += pos_embed 81 | all_outputs += [o] 82 | 83 | return all_outputs 84 | 85 | 86 | class ClipAggregation(nn.Module): 87 | """ 88 | Process each clip independently and concatenate all tokens 89 | """ 90 | 91 | def __init__( 92 | self, 93 | model, 94 | tubelet_size=2, 95 | max_frames=10000, 96 | use_pos_embed=False, 97 | attend_across_segments=False 98 | ): 99 | super().__init__() 100 | self.model = model 101 | self.tubelet_size = tubelet_size 102 | self.embed_dim = embed_dim = model.embed_dim 103 | self.num_heads = model.num_heads 104 | self.attend_across_segments = attend_across_segments 105 | # 1D-temporal pos-embedding 106 | self.pos_embed = None 107 | if use_pos_embed: 108 | max_T = max_frames // tubelet_size 109 | self.pos_embed = nn.Parameter( 110 | torch.zeros(1, max_T, embed_dim), 111 | requires_grad=False) 112 | sincos = get_1d_sincos_pos_embed(embed_dim, max_T) 113 | self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) 114 | 115 | def forward(self, x, clip_indices=None): 116 | 117 | num_clips = len(x) 118 | 
num_views_per_clip = len(x[0]) 119 | B, C, T, H, W = x[0][0].size() 120 | 121 | # Concatenate all spatial and temporal views along batch dimension 122 | x = [torch.cat(xi, dim=0) for xi in x] 123 | x = torch.cat(x, dim=0) 124 | outputs = self.model(x) 125 | _, N, D = outputs.size() 126 | 127 | T = T // self.tubelet_size # Num temporal tokens 128 | N = N // T # Num spatial tokens 129 | 130 | # Unroll outputs into a 2D array [spatial_views x temporal_views] 131 | eff_B = B * num_views_per_clip 132 | all_outputs = [[] for _ in range(num_views_per_clip)] 133 | for i in range(num_clips): 134 | o = outputs[i*eff_B:(i+1)*eff_B] 135 | for j in range(num_views_per_clip): 136 | all_outputs[j].append(o[j*B:(j+1)*B]) 137 | 138 | if not self.attend_across_segments: 139 | return all_outputs 140 | 141 | for i, outputs in enumerate(all_outputs): 142 | 143 | # Concatenate along temporal dimension 144 | outputs = [o.reshape(B, T, N, D) for o in outputs] 145 | outputs = torch.cat(outputs, dim=1).flatten(1, 2) 146 | 147 | # Compute positional embedding 148 | if (self.pos_embed is not None) and (clip_indices is not None): 149 | clip_indices = [c[:, ::self.tubelet_size] for c in clip_indices] 150 | pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] 151 | pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) 152 | pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension 153 | pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] 154 | pos_embed = pos_embed.flatten(1, 2) 155 | outputs += pos_embed 156 | 157 | all_outputs[i] = outputs 158 | 159 | return all_outputs 160 | 161 | 162 | def make_transforms( 163 | training=True, 164 | random_horizontal_flip=True, 165 | random_resize_aspect_ratio=(3/4, 4/3), 166 | random_resize_scale=(0.3, 1.0), 167 | reprob=0.0, 168 | auto_augment=False, 169 | motion_shift=False, 170 | crop_size=224, 171 | num_views_per_clip=1, 172 | normalize=((0.485, 0.456, 0.406), 173 | (0.229, 0.224, 0.225)) 174 | ): 175 | 176 | if not training and num_views_per_clip > 1: 177 | print('Making EvalVideoTransform, multi-view') 178 | _frames_augmentation = EvalVideoTransform( 179 | num_views_per_clip=num_views_per_clip, 180 | short_side_size=crop_size, 181 | normalize=normalize, 182 | ) 183 | 184 | else: 185 | _frames_augmentation = VideoTransform( 186 | training=training, 187 | random_horizontal_flip=random_horizontal_flip, 188 | random_resize_aspect_ratio=random_resize_aspect_ratio, 189 | random_resize_scale=random_resize_scale, 190 | reprob=reprob, 191 | auto_augment=auto_augment, 192 | motion_shift=motion_shift, 193 | crop_size=crop_size, 194 | normalize=normalize, 195 | ) 196 | return _frames_augmentation 197 | 198 | 199 | class VideoTransform(object): 200 | 201 | def __init__( 202 | self, 203 | training=True, 204 | random_horizontal_flip=True, 205 | random_resize_aspect_ratio=(3/4, 4/3), 206 | random_resize_scale=(0.3, 1.0), 207 | reprob=0.0, 208 | auto_augment=False, 209 | motion_shift=False, 210 | crop_size=224, 211 | normalize=((0.485, 0.456, 0.406), 212 | (0.229, 0.224, 0.225)) 213 | ): 214 | 215 | self.training = training 216 | 217 | short_side_size = int(crop_size * 256 / 224) 218 | self.eval_transform = video_transforms.Compose([ 219 | video_transforms.Resize(short_side_size, interpolation='bilinear'), 220 | video_transforms.CenterCrop(size=(crop_size, crop_size)), 221 | volume_transforms.ClipToTensor(), 222 | video_transforms.Normalize(mean=normalize[0], std=normalize[1]) 223 | ]) 224 | 225 | 
self.random_horizontal_flip = random_horizontal_flip 226 | self.random_resize_aspect_ratio = random_resize_aspect_ratio 227 | self.random_resize_scale = random_resize_scale 228 | self.auto_augment = auto_augment 229 | self.motion_shift = motion_shift 230 | self.crop_size = crop_size 231 | self.normalize = torch.tensor(normalize) 232 | 233 | self.autoaug_transform = video_transforms.create_random_augment( 234 | input_size=(crop_size, crop_size), 235 | auto_augment='rand-m7-n4-mstd0.5-inc1', 236 | interpolation='bicubic', 237 | ) 238 | 239 | self.spatial_transform = video_transforms.random_resized_crop_with_shift \ 240 | if motion_shift else video_transforms.random_resized_crop 241 | 242 | self.reprob = reprob 243 | self.erase_transform = RandomErasing( 244 | reprob, 245 | mode='pixel', 246 | max_count=1, 247 | num_splits=1, 248 | device='cpu', 249 | ) 250 | 251 | def __call__(self, buffer): 252 | 253 | if not self.training: 254 | return [self.eval_transform(buffer)] 255 | 256 | buffer = [transforms.ToPILImage()(frame) for frame in buffer] 257 | 258 | if self.auto_augment: 259 | buffer = self.autoaug_transform(buffer) 260 | 261 | buffer = [transforms.ToTensor()(img) for img in buffer] 262 | buffer = torch.stack(buffer) # T C H W 263 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 264 | 265 | buffer = tensor_normalize(buffer, self.normalize[0], self.normalize[1]) 266 | buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W 267 | 268 | buffer = self.spatial_transform( 269 | images=buffer, 270 | target_height=self.crop_size, 271 | target_width=self.crop_size, 272 | scale=self.random_resize_scale, 273 | ratio=self.random_resize_aspect_ratio, 274 | ) 275 | if self.random_horizontal_flip: 276 | buffer, _ = video_transforms.horizontal_flip(0.5, buffer) 277 | 278 | if self.reprob > 0: 279 | buffer = buffer.permute(1, 0, 2, 3) 280 | buffer = self.erase_transform(buffer) 281 | buffer = buffer.permute(1, 0, 2, 3) 282 | 283 | return [buffer] 284 | 285 | 286 | class EvalVideoTransform(object): 287 | 288 | def __init__( 289 | self, 290 | num_views_per_clip=1, 291 | short_side_size=224, 292 | normalize=((0.485, 0.456, 0.406), 293 | (0.229, 0.224, 0.225)) 294 | ): 295 | self.views_per_clip = num_views_per_clip 296 | self.short_side_size = short_side_size 297 | self.spatial_resize = video_transforms.Resize(short_side_size, interpolation='bilinear') 298 | self.to_tensor = video_transforms.Compose([ 299 | volume_transforms.ClipToTensor(), 300 | video_transforms.Normalize(mean=normalize[0], std=normalize[1]) 301 | ]) 302 | 303 | def __call__(self, buffer): 304 | 305 | # Sample several spatial views of each clip 306 | buffer = np.array(self.spatial_resize(buffer)) 307 | T, H, W, C = buffer.shape 308 | 309 | num_views = self.views_per_clip 310 | side_len = self.short_side_size 311 | spatial_step = (max(H, W) - side_len) // (num_views - 1) 312 | 313 | all_views = [] 314 | for i in range(num_views): 315 | start = i*spatial_step 316 | if H > W: 317 | view = buffer[:, start:start+side_len, :, :] 318 | else: 319 | view = buffer[:, :, start:start+side_len, :] 320 | view = self.to_tensor(view) 321 | all_views.append(view) 322 | 323 | return all_views 324 | 325 | 326 | def tensor_normalize(tensor, mean, std): 327 | """ 328 | Normalize a given tensor by subtracting the mean and dividing the std. 329 | Args: 330 | tensor (tensor): tensor to normalize. 331 | mean (tensor or list): mean value to subtract. 332 | std (tensor or list): std to divide. 
333 | """ 334 | if tensor.dtype == torch.uint8: 335 | tensor = tensor.float() 336 | tensor = tensor / 255.0 337 | if type(mean) == list: 338 | mean = torch.tensor(mean) 339 | if type(std) == list: 340 | std = torch.tensor(std) 341 | tensor = tensor - mean 342 | tensor = tensor / std 343 | return tensor 344 | -------------------------------------------------------------------------------- /evals/image_classification_frozen/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | 10 | # -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS 11 | try: 12 | # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE 13 | # -- SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE 14 | # -- THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE 15 | # -- TO EACH PROCESS 16 | os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID'] 17 | except Exception: 18 | pass 19 | 20 | import logging 21 | import pprint 22 | 23 | import numpy as np 24 | 25 | import torch 26 | import torch.multiprocessing as mp 27 | import torchvision.transforms as transforms 28 | 29 | from torch.nn.parallel import DistributedDataParallel 30 | 31 | from timm.data import create_transform as timm_make_transforms 32 | 33 | import src.models.vision_transformer as vit 34 | from src.models.attentive_pooler import AttentiveClassifier 35 | from src.datasets.data_manager import ( 36 | init_data, 37 | ) 38 | from src.utils.distributed import ( 39 | init_distributed, 40 | AllReduce 41 | ) 42 | from src.utils.schedulers import ( 43 | WarmupCosineSchedule, 44 | CosineWDSchedule, 45 | ) 46 | from src.utils.logging import ( 47 | AverageMeter, 48 | CSVLogger 49 | ) 50 | 51 | logging.basicConfig() 52 | logger = logging.getLogger() 53 | logger.setLevel(logging.INFO) 54 | 55 | _GLOBAL_SEED = 0 56 | np.random.seed(_GLOBAL_SEED) 57 | torch.manual_seed(_GLOBAL_SEED) 58 | torch.backends.cudnn.benchmark = True 59 | 60 | pp = pprint.PrettyPrinter(indent=4) 61 | 62 | 63 | def main(args_eval, resume_preempt=False): 64 | 65 | # ----------------------------------------------------------------------- # 66 | # PASSED IN PARAMS FROM CONFIG FILE 67 | # ----------------------------------------------------------------------- # 68 | 69 | # -- PRETRAIN 70 | args_pretrain = args_eval.get('pretrain') 71 | checkpoint_key = args_pretrain.get('checkpoint_key', 'target_encoder') 72 | model_name = args_pretrain.get('model_name', None) 73 | patch_size = args_pretrain.get('patch_size', None) 74 | pretrain_folder = args_pretrain.get('folder', None) 75 | ckp_fname = args_pretrain.get('checkpoint', None) 76 | tag = args_pretrain.get('write_tag', None) 77 | use_sdpa = args_pretrain.get('use_sdpa', True) 78 | use_SiLU = args_pretrain.get('use_silu', False) 79 | tight_SiLU = args_pretrain.get('tight_silu', True) 80 | uniform_power = args_pretrain.get('uniform_power', False) 81 | pretrained_path = os.path.join(pretrain_folder, ckp_fname) 82 | # Optional [for Video model]: 83 | tubelet_size = args_pretrain.get('tubelet_size', 2) 84 | frames_per_clip = args_pretrain.get('frames_per_clip', 1) 85 | 86 | # -- DATA 87 | args_data = args_eval.get('data') 88 | dataset_name = args_data.get('dataset_name') 89 | num_classes = args_data.get('num_classes') 90 | root_path = 
args_data.get('root_path', None) 91 | image_folder = args_data.get('image_folder', None) 92 | resolution = args_data.get('resolution', 224) 93 | 94 | # -- OPTIMIZATION 95 | args_opt = args_eval.get('optimization') 96 | batch_size = args_opt.get('batch_size') 97 | num_epochs = args_opt.get('num_epochs') 98 | wd = args_opt.get('weight_decay') 99 | start_lr = args_opt.get('start_lr') 100 | lr = args_opt.get('lr') 101 | final_lr = args_opt.get('final_lr') 102 | warmup = args_opt.get('warmup') 103 | use_bfloat16 = args_opt.get('use_bfloat16') 104 | 105 | # -- EXPERIMENT-ID/TAG (optional) 106 | resume_checkpoint = args_eval.get('resume_checkpoint', False) or resume_preempt 107 | eval_tag = args_eval.get('tag', None) 108 | 109 | # ----------------------------------------------------------------------- # 110 | 111 | try: 112 | mp.set_start_method('spawn') 113 | except Exception: 114 | pass 115 | 116 | if not torch.cuda.is_available(): 117 | device = torch.device('cpu') 118 | else: 119 | device = torch.device('cuda:0') 120 | torch.cuda.set_device(device) 121 | 122 | world_size, rank = init_distributed() 123 | logger.info(f'Initialized (rank/world-size) {rank}/{world_size}') 124 | 125 | # -- log/checkpointing paths 126 | folder = os.path.join(pretrain_folder, 'image_classification_frozen/') 127 | if eval_tag is not None: 128 | folder = os.path.join(folder, eval_tag) 129 | if not os.path.exists(folder): 130 | os.makedirs(folder, exist_ok=True) 131 | log_file = os.path.join(folder, f'{tag}_r{rank}.csv') 132 | latest_path = os.path.join(folder, f'{tag}-latest.pth.tar') 133 | 134 | # -- make csv_logger 135 | if rank == 0: 136 | csv_logger = CSVLogger(log_file, 137 | ('%d', 'epoch'), 138 | ('%.5f', 'loss'), 139 | ('%.5f', 'acc')) 140 | 141 | # Initialize model 142 | 143 | # -- pretrained encoder (frozen) 144 | encoder = init_model( 145 | crop_size=resolution, 146 | device=device, 147 | pretrained=pretrained_path, 148 | model_name=model_name, 149 | patch_size=patch_size, 150 | frames_per_clip=frames_per_clip, 151 | tubelet_size=tubelet_size, 152 | uniform_power=uniform_power, 153 | checkpoint_key=checkpoint_key, 154 | use_SiLU=use_SiLU, 155 | tight_SiLU=tight_SiLU, 156 | use_sdpa=use_sdpa) 157 | encoder.eval() 158 | for p in encoder.parameters(): 159 | p.requires_grad = False 160 | 161 | # -- init classifier 162 | classifier = AttentiveClassifier( 163 | embed_dim=encoder.embed_dim, 164 | num_heads=encoder.num_heads, 165 | depth=1, 166 | num_classes=num_classes 167 | ).to(device) 168 | 169 | train_loader = make_dataloader( 170 | dataset_name=dataset_name, 171 | root_path=root_path, 172 | resolution=resolution, 173 | image_folder=image_folder, 174 | batch_size=batch_size, 175 | world_size=world_size, 176 | rank=rank, 177 | training=True) 178 | val_loader = make_dataloader( 179 | dataset_name=dataset_name, 180 | root_path=root_path, 181 | resolution=resolution, 182 | image_folder=image_folder, 183 | batch_size=batch_size, 184 | world_size=world_size, 185 | rank=rank, 186 | training=False) 187 | ipe = len(train_loader) 188 | logger.info(f'Dataloader created... 
iterations per epoch: {ipe}') 189 | 190 | # -- optimizer and scheduler 191 | optimizer, scaler, scheduler, wd_scheduler = init_opt( 192 | classifier=classifier, 193 | wd=wd, 194 | start_lr=start_lr, 195 | ref_lr=lr, 196 | final_lr=final_lr, 197 | iterations_per_epoch=ipe, 198 | warmup=warmup, 199 | num_epochs=num_epochs, 200 | use_bfloat16=use_bfloat16) 201 | classifier = DistributedDataParallel(classifier, static_graph=True) 202 | 203 | # -- load training checkpoint 204 | start_epoch = 0 205 | if resume_checkpoint: 206 | classifier, optimizer, scaler, start_epoch = load_checkpoint( 207 | device=device, 208 | r_path=latest_path, 209 | classifier=classifier, 210 | opt=optimizer, 211 | scaler=scaler) 212 | for _ in range(start_epoch*ipe): 213 | scheduler.step() 214 | wd_scheduler.step() 215 | 216 | def save_checkpoint(epoch): 217 | save_dict = { 218 | 'classifier': classifier.state_dict(), 219 | 'opt': optimizer.state_dict(), 220 | 'scaler': None if scaler is None else scaler.state_dict(), 221 | 'epoch': epoch, 222 | 'batch_size': batch_size, 223 | 'world_size': world_size, 224 | 'lr': lr 225 | } 226 | if rank == 0: 227 | torch.save(save_dict, latest_path) 228 | 229 | # TRAIN LOOP 230 | for epoch in range(start_epoch, num_epochs): 231 | logger.info('Epoch %d' % (epoch + 1)) 232 | train_acc = run_one_epoch( 233 | device=device, 234 | training=True, 235 | encoder=encoder, 236 | classifier=classifier, 237 | scaler=scaler, 238 | optimizer=optimizer, 239 | scheduler=scheduler, 240 | wd_scheduler=wd_scheduler, 241 | data_loader=train_loader, 242 | use_bfloat16=use_bfloat16) 243 | 244 | val_acc = run_one_epoch( 245 | device=device, 246 | training=False, 247 | encoder=encoder, 248 | classifier=classifier, 249 | scaler=scaler, 250 | optimizer=optimizer, 251 | scheduler=scheduler, 252 | wd_scheduler=wd_scheduler, 253 | data_loader=val_loader, 254 | use_bfloat16=use_bfloat16) 255 | 256 | logger.info('[%5d] train: %.3f%% test: %.3f%%' % (epoch + 1, train_acc, val_acc)) 257 | if rank == 0: 258 | csv_logger.log(epoch + 1, train_acc, val_acc) 259 | save_checkpoint(epoch + 1) 260 | 261 | 262 | def run_one_epoch( 263 | device, 264 | training, 265 | encoder, 266 | classifier, 267 | scaler, 268 | optimizer, 269 | scheduler, 270 | wd_scheduler, 271 | data_loader, 272 | use_bfloat16, 273 | ): 274 | 275 | classifier.train(mode=training) 276 | criterion = torch.nn.CrossEntropyLoss() 277 | top1_meter = AverageMeter() 278 | for itr, data in enumerate(data_loader): 279 | 280 | if training: 281 | scheduler.step() 282 | wd_scheduler.step() 283 | 284 | with torch.cuda.amp.autocast(dtype=torch.float16, enabled=use_bfloat16): 285 | 286 | imgs, labels = data[0].to(device), data[1].to(device) 287 | with torch.no_grad(): 288 | outputs = encoder(imgs) 289 | if not training: 290 | outputs = classifier(outputs) 291 | if training: 292 | outputs = classifier(outputs) 293 | 294 | loss = criterion(outputs, labels) 295 | top1_acc = 100. 
* outputs.max(dim=1).indices.eq(labels).sum() / len(imgs) 296 | top1_acc = float(AllReduce.apply(top1_acc)) 297 | top1_meter.update(top1_acc) 298 | 299 | if training: 300 | if use_bfloat16: 301 | scaler.scale(loss).backward() 302 | scaler.unscale_(optimizer) 303 | torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0) 304 | scaler.step(optimizer) 305 | scaler.update() 306 | else: 307 | loss.backward() 308 | torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0) 309 | optimizer.step() 310 | optimizer.zero_grad() 311 | 312 | if itr % 20 == 0: 313 | logger.info('[%5d] %.3f%% (loss: %.3f) [mem: %.2e]' 314 | % (itr, top1_meter.avg, loss, 315 | torch.cuda.max_memory_allocated() / 1024.**2)) 316 | 317 | return top1_meter.avg 318 | 319 | 320 | def load_checkpoint( 321 | device, 322 | r_path, 323 | classifier, 324 | opt, 325 | scaler 326 | ): 327 | try: 328 | checkpoint = torch.load(r_path, map_location=torch.device('cpu')) 329 | epoch = checkpoint['epoch'] 330 | 331 | # -- loading encoder 332 | pretrained_dict = checkpoint['classifier'] 333 | msg = classifier.load_state_dict(pretrained_dict) 334 | logger.info(f'loaded pretrained classifier from epoch {epoch} with msg: {msg}') 335 | 336 | # -- loading optimizer 337 | opt.load_state_dict(checkpoint['opt']) 338 | if scaler is not None: 339 | scaler.load_state_dict(checkpoint['scaler']) 340 | logger.info(f'loaded optimizers from epoch {epoch}') 341 | logger.info(f'read-path: {r_path}') 342 | del checkpoint 343 | 344 | except Exception as e: 345 | logger.info(f'Encountered exception when loading checkpoint {e}') 346 | epoch = 0 347 | 348 | return classifier, opt, scaler, epoch 349 | 350 | 351 | def load_pretrained( 352 | encoder, 353 | pretrained, 354 | checkpoint_key='target_encoder' 355 | ): 356 | logger.info(f'Loading pretrained model from {pretrained}') 357 | checkpoint = torch.load(pretrained, map_location='cpu') 358 | try: 359 | pretrained_dict = checkpoint[checkpoint_key] 360 | except Exception: 361 | pretrained_dict = checkpoint['encoder'] 362 | 363 | pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items()} 364 | pretrained_dict = {k.replace('backbone.', ''): v for k, v in pretrained_dict.items()} 365 | for k, v in encoder.state_dict().items(): 366 | if k not in pretrained_dict: 367 | logger.info(f'key "{k}" could not be found in loaded state dict') 368 | elif pretrained_dict[k].shape != v.shape: 369 | logger.info(f'key "{k}" is of different shape in model and loaded state dict') 370 | pretrained_dict[k] = v 371 | msg = encoder.load_state_dict(pretrained_dict, strict=False) 372 | print(encoder) 373 | logger.info(f'loaded pretrained model with msg: {msg}') 374 | logger.info(f'loaded pretrained encoder from epoch: {checkpoint["epoch"]}\n path: {pretrained}') 375 | del checkpoint 376 | return encoder 377 | 378 | 379 | def make_dataloader( 380 | dataset_name, 381 | root_path, 382 | image_folder, 383 | batch_size, 384 | world_size, 385 | rank, 386 | resolution=224, 387 | training=False, 388 | subset_file=None 389 | ): 390 | normalization = ((0.485, 0.456, 0.406), 391 | (0.229, 0.224, 0.225)) 392 | if training: 393 | logger.info('implementing auto-agument strategy') 394 | transform = timm_make_transforms( 395 | input_size=resolution, 396 | is_training=training, 397 | auto_augment='original', 398 | interpolation='bicubic', 399 | re_prob=0.25, 400 | re_mode='pixel', 401 | re_count=1, 402 | mean=normalization[0], 403 | std=normalization[1]) 404 | else: 405 | transform = transforms.Compose([ 406 | 
transforms.Resize(size=int(resolution * 256/224)), 407 | transforms.CenterCrop(size=resolution), 408 | transforms.ToTensor(), 409 | transforms.Normalize(normalization[0], normalization[1])]) 410 | 411 | data_loader, _ = init_data( 412 | data=dataset_name, 413 | transform=transform, 414 | batch_size=batch_size, 415 | world_size=world_size, 416 | rank=rank, 417 | root_path=root_path, 418 | image_folder=image_folder, 419 | training=training, 420 | copy_data=False, 421 | drop_last=False, 422 | subset_file=subset_file) 423 | return data_loader 424 | 425 | 426 | def init_model( 427 | device, 428 | pretrained, 429 | model_name, 430 | patch_size=16, 431 | crop_size=224, 432 | # Video specific parameters 433 | frames_per_clip=16, 434 | tubelet_size=2, 435 | use_sdpa=False, 436 | use_SiLU=False, 437 | tight_SiLU=True, 438 | uniform_power=False, 439 | checkpoint_key='target_encoder' 440 | ): 441 | encoder = vit.__dict__[model_name]( 442 | img_size=crop_size, 443 | patch_size=patch_size, 444 | num_frames=frames_per_clip, 445 | tubelet_size=tubelet_size, 446 | uniform_power=uniform_power, 447 | use_sdpa=use_sdpa, 448 | use_SiLU=use_SiLU, 449 | tight_SiLU=tight_SiLU, 450 | ) 451 | if frames_per_clip > 1: 452 | def forward_prehook(module, input): 453 | input = input[0] # [B, C, H, W] 454 | input = input.unsqueeze(2).repeat(1, 1, frames_per_clip, 1, 1) 455 | return (input) 456 | 457 | encoder.register_forward_pre_hook(forward_prehook) 458 | 459 | encoder.to(device) 460 | encoder = load_pretrained(encoder=encoder, pretrained=pretrained, checkpoint_key=checkpoint_key) 461 | return encoder 462 | 463 | 464 | def init_opt( 465 | classifier, 466 | iterations_per_epoch, 467 | start_lr, 468 | ref_lr, 469 | warmup, 470 | num_epochs, 471 | wd=1e-6, 472 | final_wd=1e-6, 473 | final_lr=0.0, 474 | use_bfloat16=False 475 | ): 476 | param_groups = [ 477 | { 478 | 'params': (p for n, p in classifier.named_parameters() 479 | if ('bias' not in n) and (len(p.shape) != 1)) 480 | }, { 481 | 'params': (p for n, p in classifier.named_parameters() 482 | if ('bias' in n) or (len(p.shape) == 1)), 483 | 'WD_exclude': True, 484 | 'weight_decay': 0 485 | } 486 | ] 487 | 488 | logger.info('Using AdamW') 489 | optimizer = torch.optim.AdamW(param_groups) 490 | scheduler = WarmupCosineSchedule( 491 | optimizer, 492 | warmup_steps=int(warmup*iterations_per_epoch), 493 | start_lr=start_lr, 494 | ref_lr=ref_lr, 495 | final_lr=final_lr, 496 | T_max=int(num_epochs*iterations_per_epoch)) 497 | wd_scheduler = CosineWDSchedule( 498 | optimizer, 499 | ref_wd=wd, 500 | final_wd=final_wd, 501 | T_max=int(num_epochs*iterations_per_epoch)) 502 | scaler = torch.cuda.amp.GradScaler() if use_bfloat16 else None 503 | return optimizer, scaler, scheduler, wd_scheduler 504 | -------------------------------------------------------------------------------- /src/datasets/utils/video/randaugment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | """ 9 | This implementation is based on 10 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 11 | pulished under an Apache License 2.0. 
12 | """ 13 | 14 | import math 15 | import numpy as np 16 | import random 17 | import re 18 | import PIL 19 | from PIL import Image, ImageEnhance, ImageOps 20 | 21 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 22 | 23 | _FILL = (128, 128, 128) 24 | 25 | # This signifies the max integer that the controller RNN could predict for the 26 | # augmentation scheme. 27 | _MAX_LEVEL = 10.0 28 | 29 | _HPARAMS_DEFAULT = { 30 | "translate_const": 250, 31 | "img_mean": _FILL, 32 | } 33 | 34 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 35 | 36 | 37 | def _interpolation(kwargs): 38 | interpolation = kwargs.pop("resample", Image.BILINEAR) 39 | if isinstance(interpolation, (list, tuple)): 40 | return random.choice(interpolation) 41 | else: 42 | return interpolation 43 | 44 | 45 | def _check_args_tf(kwargs): 46 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 47 | kwargs.pop("fillcolor") 48 | kwargs["resample"] = _interpolation(kwargs) 49 | 50 | 51 | def shear_x(img, factor, **kwargs): 52 | _check_args_tf(kwargs) 53 | return img.transform( 54 | img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs 55 | ) 56 | 57 | 58 | def shear_y(img, factor, **kwargs): 59 | _check_args_tf(kwargs) 60 | return img.transform( 61 | img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs 62 | ) 63 | 64 | 65 | def translate_x_rel(img, pct, **kwargs): 66 | pixels = pct * img.size[0] 67 | _check_args_tf(kwargs) 68 | return img.transform( 69 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 70 | ) 71 | 72 | 73 | def translate_y_rel(img, pct, **kwargs): 74 | pixels = pct * img.size[1] 75 | _check_args_tf(kwargs) 76 | return img.transform( 77 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 78 | ) 79 | 80 | 81 | def translate_x_abs(img, pixels, **kwargs): 82 | _check_args_tf(kwargs) 83 | return img.transform( 84 | img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs 85 | ) 86 | 87 | 88 | def translate_y_abs(img, pixels, **kwargs): 89 | _check_args_tf(kwargs) 90 | return img.transform( 91 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs 92 | ) 93 | 94 | 95 | def rotate(img, degrees, **kwargs): 96 | _check_args_tf(kwargs) 97 | if _PIL_VER >= (5, 2): 98 | return img.rotate(degrees, **kwargs) 99 | elif _PIL_VER >= (5, 0): 100 | w, h = img.size 101 | post_trans = (0, 0) 102 | rotn_center = (w / 2.0, h / 2.0) 103 | angle = -math.radians(degrees) 104 | matrix = [ 105 | round(math.cos(angle), 15), 106 | round(math.sin(angle), 15), 107 | 0.0, 108 | round(-math.sin(angle), 15), 109 | round(math.cos(angle), 15), 110 | 0.0, 111 | ] 112 | 113 | def transform(x, y, matrix): 114 | (a, b, c, d, e, f) = matrix 115 | return a * x + b * y + c, d * x + e * y + f 116 | 117 | matrix[2], matrix[5] = transform( 118 | -rotn_center[0] - post_trans[0], 119 | -rotn_center[1] - post_trans[1], 120 | matrix, 121 | ) 122 | matrix[2] += rotn_center[0] 123 | matrix[5] += rotn_center[1] 124 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 125 | else: 126 | return img.rotate(degrees, resample=kwargs["resample"]) 127 | 128 | 129 | def auto_contrast(img, **__): 130 | return ImageOps.autocontrast(img) 131 | 132 | 133 | def invert(img, **__): 134 | return ImageOps.invert(img) 135 | 136 | 137 | def equalize(img, **__): 138 | return ImageOps.equalize(img) 139 | 140 | 141 | def solarize(img, thresh, **__): 142 | return ImageOps.solarize(img, thresh) 143 | 144 | 145 | def solarize_add(img, add, thresh=128, **__): 146 | lut = [] 147 | for i in range(256): 148 | if i < thresh: 149 | 
lut.append(min(255, i + add)) 150 | else: 151 | lut.append(i) 152 | if img.mode in ("L", "RGB"): 153 | if img.mode == "RGB" and len(lut) == 256: 154 | lut = lut + lut + lut 155 | return img.point(lut) 156 | else: 157 | return img 158 | 159 | 160 | def posterize(img, bits_to_keep, **__): 161 | if bits_to_keep >= 8: 162 | return img 163 | return ImageOps.posterize(img, bits_to_keep) 164 | 165 | 166 | def contrast(img, factor, **__): 167 | return ImageEnhance.Contrast(img).enhance(factor) 168 | 169 | 170 | def color(img, factor, **__): 171 | return ImageEnhance.Color(img).enhance(factor) 172 | 173 | 174 | def brightness(img, factor, **__): 175 | return ImageEnhance.Brightness(img).enhance(factor) 176 | 177 | 178 | def sharpness(img, factor, **__): 179 | return ImageEnhance.Sharpness(img).enhance(factor) 180 | 181 | 182 | def _randomly_negate(v): 183 | """With 50% prob, negate the value""" 184 | return -v if random.random() > 0.5 else v 185 | 186 | 187 | def _rotate_level_to_arg(level, _hparams): 188 | # range [-30, 30] 189 | level = (level / _MAX_LEVEL) * 30.0 190 | level = _randomly_negate(level) 191 | return (level,) 192 | 193 | 194 | def _enhance_level_to_arg(level, _hparams): 195 | # range [0.1, 1.9] 196 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 197 | 198 | 199 | def _enhance_increasing_level_to_arg(level, _hparams): 200 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 201 | # range [0.1, 1.9] 202 | level = (level / _MAX_LEVEL) * 0.9 203 | level = 1.0 + _randomly_negate(level) 204 | return (level,) 205 | 206 | 207 | def _shear_level_to_arg(level, _hparams): 208 | # range [-0.3, 0.3] 209 | level = (level / _MAX_LEVEL) * 0.3 210 | level = _randomly_negate(level) 211 | return (level,) 212 | 213 | 214 | def _translate_abs_level_to_arg(level, hparams): 215 | translate_const = hparams["translate_const"] 216 | level = (level / _MAX_LEVEL) * float(translate_const) 217 | level = _randomly_negate(level) 218 | return (level,) 219 | 220 | 221 | def _translate_rel_level_to_arg(level, hparams): 222 | # default range [-0.45, 0.45] 223 | translate_pct = hparams.get("translate_pct", 0.45) 224 | level = (level / _MAX_LEVEL) * translate_pct 225 | level = _randomly_negate(level) 226 | return (level,) 227 | 228 | 229 | def _posterize_level_to_arg(level, _hparams): 230 | # As per Tensorflow TPU EfficientNet impl 231 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 232 | # intensity/severity of augmentation decreases with level 233 | return (int((level / _MAX_LEVEL) * 4),) 234 | 235 | 236 | def _posterize_increasing_level_to_arg(level, hparams): 237 | # As per Tensorflow models research and UDA impl 238 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 239 | # intensity/severity of augmentation increases with level 240 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 241 | 242 | 243 | def _posterize_original_level_to_arg(level, _hparams): 244 | # As per original AutoAugment paper description 245 | # range [4, 8], 'keep 4 up to 8 MSB of image' 246 | # intensity/severity of augmentation decreases with level 247 | return (int((level / _MAX_LEVEL) * 4) + 4,) 248 | 249 | 250 | def _solarize_level_to_arg(level, _hparams): 251 | # range [0, 256] 252 | # intensity/severity of augmentation decreases with level 253 | return (int((level / _MAX_LEVEL) * 256),) 254 | 255 | 256 | def _solarize_increasing_level_to_arg(level, _hparams): 257 | # range [0, 256] 258 | # intensity/severity of augmentation increases with level 259 | return (256 
- _solarize_level_to_arg(level, _hparams)[0],) 260 | 261 | 262 | def _solarize_add_level_to_arg(level, _hparams): 263 | # range [0, 110] 264 | return (int((level / _MAX_LEVEL) * 110),) 265 | 266 | 267 | LEVEL_TO_ARG = { 268 | "AutoContrast": None, 269 | "Equalize": None, 270 | "Invert": None, 271 | "Rotate": _rotate_level_to_arg, 272 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 273 | "Posterize": _posterize_level_to_arg, 274 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 275 | "PosterizeOriginal": _posterize_original_level_to_arg, 276 | "Solarize": _solarize_level_to_arg, 277 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 278 | "SolarizeAdd": _solarize_add_level_to_arg, 279 | "Color": _enhance_level_to_arg, 280 | "ColorIncreasing": _enhance_increasing_level_to_arg, 281 | "Contrast": _enhance_level_to_arg, 282 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 283 | "Brightness": _enhance_level_to_arg, 284 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 285 | "Sharpness": _enhance_level_to_arg, 286 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 287 | "ShearX": _shear_level_to_arg, 288 | "ShearY": _shear_level_to_arg, 289 | "TranslateX": _translate_abs_level_to_arg, 290 | "TranslateY": _translate_abs_level_to_arg, 291 | "TranslateXRel": _translate_rel_level_to_arg, 292 | "TranslateYRel": _translate_rel_level_to_arg, 293 | } 294 | 295 | 296 | NAME_TO_OP = { 297 | "AutoContrast": auto_contrast, 298 | "Equalize": equalize, 299 | "Invert": invert, 300 | "Rotate": rotate, 301 | "Posterize": posterize, 302 | "PosterizeIncreasing": posterize, 303 | "PosterizeOriginal": posterize, 304 | "Solarize": solarize, 305 | "SolarizeIncreasing": solarize, 306 | "SolarizeAdd": solarize_add, 307 | "Color": color, 308 | "ColorIncreasing": color, 309 | "Contrast": contrast, 310 | "ContrastIncreasing": contrast, 311 | "Brightness": brightness, 312 | "BrightnessIncreasing": brightness, 313 | "Sharpness": sharpness, 314 | "SharpnessIncreasing": sharpness, 315 | "ShearX": shear_x, 316 | "ShearY": shear_y, 317 | "TranslateX": translate_x_abs, 318 | "TranslateY": translate_y_abs, 319 | "TranslateXRel": translate_x_rel, 320 | "TranslateYRel": translate_y_rel, 321 | } 322 | 323 | 324 | class AugmentOp: 325 | """ 326 | Apply for video. 327 | """ 328 | 329 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 330 | hparams = hparams or _HPARAMS_DEFAULT 331 | self.aug_fn = NAME_TO_OP[name] 332 | self.level_fn = LEVEL_TO_ARG[name] 333 | self.prob = prob 334 | self.magnitude = magnitude 335 | self.hparams = hparams.copy() 336 | self.kwargs = { 337 | "fillcolor": hparams["img_mean"] 338 | if "img_mean" in hparams 339 | else _FILL, 340 | "resample": hparams["interpolation"] 341 | if "interpolation" in hparams 342 | else _RANDOM_INTERPOLATION, 343 | } 344 | 345 | # If magnitude_std is > 0, we introduce some randomness 346 | # in the usually fixed policy and sample magnitude from a normal distribution 347 | # with mean `magnitude` and std-dev of `magnitude_std`. 348 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
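        # (Editorial comment, not in the original file; illustrative numbers only.)
        # Concretely, assuming magnitude=9 and magnitude_std=0.5, each __call__ below
        # draws a magnitude from roughly N(9, 0.5) and clips it to [0, _MAX_LEVEL]
        # before level_fn maps it to the op's argument range.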
349 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 350 | 351 | def __call__(self, img_list): 352 | if self.prob < 1.0 and random.random() > self.prob: 353 | return img_list 354 | magnitude = self.magnitude 355 | if self.magnitude_std and self.magnitude_std > 0: 356 | magnitude = random.gauss(magnitude, self.magnitude_std) 357 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 358 | level_args = ( 359 | self.level_fn(magnitude, self.hparams) 360 | if self.level_fn is not None 361 | else () 362 | ) 363 | 364 | if isinstance(img_list, list): 365 | return [ 366 | self.aug_fn(img, *level_args, **self.kwargs) for img in img_list 367 | ] 368 | else: 369 | return self.aug_fn(img_list, *level_args, **self.kwargs) 370 | 371 | 372 | _RAND_TRANSFORMS = [ 373 | "AutoContrast", 374 | "Equalize", 375 | "Invert", 376 | "Rotate", 377 | "Posterize", 378 | "Solarize", 379 | "SolarizeAdd", 380 | "Color", 381 | "Contrast", 382 | "Brightness", 383 | "Sharpness", 384 | "ShearX", 385 | "ShearY", 386 | "TranslateXRel", 387 | "TranslateYRel", 388 | ] 389 | 390 | 391 | _RAND_INCREASING_TRANSFORMS = [ 392 | "AutoContrast", 393 | "Equalize", 394 | "Invert", 395 | "Rotate", 396 | "PosterizeIncreasing", 397 | "SolarizeIncreasing", 398 | "SolarizeAdd", 399 | "ColorIncreasing", 400 | "ContrastIncreasing", 401 | "BrightnessIncreasing", 402 | "SharpnessIncreasing", 403 | "ShearX", 404 | "ShearY", 405 | "TranslateXRel", 406 | "TranslateYRel", 407 | ] 408 | 409 | 410 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 411 | # They may not result in increased performance, but could likely be tuned to so. 412 | _RAND_CHOICE_WEIGHTS_0 = { 413 | "Rotate": 0.3, 414 | "ShearX": 0.2, 415 | "ShearY": 0.2, 416 | "TranslateXRel": 0.1, 417 | "TranslateYRel": 0.1, 418 | "Color": 0.025, 419 | "Sharpness": 0.025, 420 | "AutoContrast": 0.025, 421 | "Solarize": 0.005, 422 | "SolarizeAdd": 0.005, 423 | "Contrast": 0.005, 424 | "Brightness": 0.005, 425 | "Equalize": 0.005, 426 | "Posterize": 0, 427 | "Invert": 0, 428 | } 429 | 430 | 431 | def _select_rand_weights(weight_idx=0, transforms=None): 432 | transforms = transforms or _RAND_TRANSFORMS 433 | assert weight_idx == 0 # only one set of weights currently 434 | rand_weights = _RAND_CHOICE_WEIGHTS_0 435 | probs = [rand_weights[k] for k in transforms] 436 | probs /= np.sum(probs) 437 | return probs 438 | 439 | 440 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 441 | hparams = hparams or _HPARAMS_DEFAULT 442 | transforms = transforms or _RAND_TRANSFORMS 443 | return [ 444 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 445 | for name in transforms 446 | ] 447 | 448 | 449 | class RandAugment: 450 | def __init__(self, ops, num_layers=2, choice_weights=None): 451 | self.ops = ops 452 | self.num_layers = num_layers 453 | self.choice_weights = choice_weights 454 | 455 | def __call__(self, img): 456 | # no replacement when using weighted choice 457 | ops = np.random.choice( 458 | self.ops, 459 | self.num_layers, 460 | replace=self.choice_weights is None, 461 | p=self.choice_weights, 462 | ) 463 | for op in ops: 464 | img = op(img) 465 | return img 466 | 467 | 468 | def rand_augment_transform(config_str, hparams): 469 | """ 470 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 471 | 472 | Create a RandAugment transform 473 | :param config_str: String defining configuration of random augmentation. 
Consists of multiple sections separated by 474 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 475 | sections, not order specific, determine 476 | 'm' - integer magnitude of rand augment 477 | 'n' - integer num layers (number of transform ops selected per image) 478 | 'w' - integer probability weight index (index of a set of weights to influence choice of op) 479 | 'mstd' - float std deviation of magnitude noise applied 480 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 481 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 482 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 483 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 484 | :return: A PyTorch compatible Transform 485 | """ 486 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 487 | num_layers = 2 # default to 2 ops per image 488 | weight_idx = None # default to no probability weights for op choice 489 | transforms = _RAND_TRANSFORMS 490 | config = config_str.split("-") 491 | assert config[0] == "rand" 492 | config = config[1:] 493 | for c in config: 494 | cs = re.split(r"(\d.*)", c) 495 | if len(cs) < 2: 496 | continue 497 | key, val = cs[:2] 498 | if key == "mstd": 499 | # noise param injected via hparams for now 500 | hparams.setdefault("magnitude_std", float(val)) 501 | elif key == "inc": 502 | if bool(val): 503 | transforms = _RAND_INCREASING_TRANSFORMS 504 | elif key == "m": 505 | magnitude = int(val) 506 | elif key == "n": 507 | num_layers = int(val) 508 | elif key == "w": 509 | weight_idx = int(val) 510 | else: 511 | raise NotImplementedError(f"Unknown RandAugment config section: {key}") 512 | ra_ops = rand_augment_ops( 513 | magnitude=magnitude, hparams=hparams, transforms=transforms 514 | ) 515 | choice_weights = ( 516 | None if weight_idx is None else _select_rand_weights(weight_idx) 517 | ) 518 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 519 | --------------------------------------------------------------------------------
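Editor's usage sketch (not part of the repository): a minimal, hypothetical example of driving rand_augment_transform() from src/datasets/utils/video/randaugment.py above. The config string, hparams values, and dummy frames are illustrative assumptions rather than values taken from this repo's configs, and the import assumes the repository root is on PYTHONPATH.

from PIL import Image

from src.datasets.utils.video.randaugment import rand_augment_transform

# 'rand-m7-n4-mstd0.5-inc1' -> magnitude 7, 4 ops per call, magnitude noise std 0.5,
# and the "increasing severity" transform set (see the rand_augment_transform docstring).
aug = rand_augment_transform(
    config_str="rand-m7-n4-mstd0.5-inc1",
    hparams={"translate_const": 100, "img_mean": (124, 116, 104)},
)

# AugmentOp.__call__ accepts either a single PIL image or a list of frames, so the
# same sampled ops (and magnitudes) are applied across this dummy 16-frame clip.
frames = [Image.new("RGB", (224, 224), color=(128, 128, 128)) for _ in range(16)]
augmented_frames = aug(frames)
print(len(augmented_frames), augmented_frames[0].size)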