├── .flake8
├── .github
│   └── workflows
│       ├── base_tests.yaml
│       └── linters.yaml
├── .gitignore
├── APACHE-LICENSE
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app
│   ├── main.py
│   ├── main_distributed.py
│   ├── scaffold.py
│   ├── vjepa
│   │   ├── train.py
│   │   ├── transforms.py
│   │   └── utils.py
│   └── vjepa_droid
│       ├── droid.py
│       ├── train.py
│       ├── transforms.py
│       └── utils.py
├── assets
│   ├── flowchart.png
│   ├── vjepa2-abstract-new.png
│   └── vjepa2-ac-abstract-new.png
├── configs
│   ├── eval
│   │   ├── vitg-384
│   │   │   ├── coin.yaml
│   │   │   ├── diving48.yaml
│   │   │   ├── ek100.yaml
│   │   │   ├── in1k.yaml
│   │   │   ├── jester.yaml
│   │   │   ├── k400.yaml
│   │   │   └── ssv2.yaml
│   │   └── vitl
│   │       ├── coin.yaml
│   │       ├── diving48.yaml
│   │       ├── ek100.yaml
│   │       ├── in1k.yaml
│   │       ├── jester.yaml
│   │       ├── k400.yaml
│   │       └── ssv2.yaml
│   ├── inference
│   │   ├── vitg-384
│   │   │   ├── diving48.yaml
│   │   │   ├── ek100.yaml
│   │   │   └── ssv2.yaml
│   │   └── vitl
│   │       ├── diving48.yaml
│   │       ├── ek100.yaml
│   │       └── ssv2.yaml
│   └── train
│       ├── vitg16
│       │   ├── cooldown-256px-64f.yaml
│       │   ├── cooldown-384px-64f.yaml
│       │   ├── droid-256px-8f.yaml
│       │   └── pretrain-256px-16f.yaml
│       ├── vith16
│       │   ├── cooldown-256px-64f.yaml
│       │   └── pretrain-256px-16f.yaml
│       └── vitl16
│           ├── cooldown-256px-64f.yaml
│           └── pretrain-256px-16f.yaml
├── evals
│   ├── action_anticipation_frozen
│   │   ├── dataloader.py
│   │   ├── epickitchens.py
│   │   ├── eval.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── modelcustom
│   │   │   └── vit_encoder_predictor_concat_ar.py
│   │   ├── models.py
│   │   └── utils.py
│   ├── hub
│   │   ├── __init__.py
│   │   └── preprocessor.py
│   ├── image_classification_frozen
│   │   ├── eval.py
│   │   ├── modelcustom
│   │   │   └── vit_encoder.py
│   │   └── models.py
│   ├── main.py
│   ├── main_distributed.py
│   ├── scaffold.py
│   └── video_classification_frozen
│       ├── eval.py
│       ├── modelcustom
│       │   ├── vit_encoder_multiclip.py
│       │   └── vit_encoder_multiclip_multilevel.py
│       ├── models.py
│       └── utils.py
├── hubconf.py
├── notebooks
│   ├── energy_landscape_example.ipynb
│   ├── franka_example_traj.npz
│   ├── utils
│   │   ├── mpc_utils.py
│   │   └── world_model_wrapper.py
│   ├── vjepa2_demo.ipynb
│   └── vjepa2_demo.py
├── pyproject.toml
├── requirements-test.txt
├── requirements.txt
├── setup.py
├── src
│   ├── datasets
│   │   ├── data_manager.py
│   │   ├── imagenet1k.py
│   │   ├── utils
│   │   │   ├── dataloader.py
│   │   │   ├── utils.py
│   │   │   ├── video
│   │   │   │   ├── functional.py
│   │   │   │   ├── randaugment.py
│   │   │   │   ├── randerase.py
│   │   │   │   ├── transforms.py
│   │   │   │   ├── transforms_builder.py
│   │   │   │   └── volume_transforms.py
│   │   │   ├── weighted_sampler.py
│   │   │   └── worker_init_fn.py
│   │   └── video_dataset.py
│   ├── hub
│   │   ├── __init__.py
│   │   └── backbones.py
│   ├── masks
│   │   ├── default.py
│   │   ├── multiseq_multiblock3d.py
│   │   └── utils.py
│   ├── models
│   │   ├── ac_predictor.py
│   │   ├── attentive_pooler.py
│   │   ├── predictor.py
│   │   ├── utils
│   │   │   ├── modules.py
│   │   │   ├── patch_embed.py
│   │   │   └── pos_embs.py
│   │   └── vision_transformer.py
│   └── utils
│       ├── checkpoint_loader.py
│       ├── distributed.py
│       ├── logging.py
│       ├── monitoring.py
│       ├── schedulers.py
│       ├── tensors.py
│       └── wrappers.py
└── tests
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── test_dataloader.py
    │   ├── test_memory_efficient_sampler.py
    │   └── test_vjepa_transforms.py
    └── models
        ├── __init__.py
        ├── test_models.py
        ├── test_predictor.py
        └── test_vision_transformer.py

/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | select = E,F,W
4 | ignore = E203,E701,W503
5 | per-file-ignores=__init__.py:F401 version.py:F401
6 |
--------------------------------------------------------------------------------
/.github/workflows/base_tests.yaml:
-------------------------------------------------------------------------------- 1 | name: UnitTests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | unittests: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | max-parallel: 4 10 | 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 3.12 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.12' 17 | - name: Add conda to system path 18 | run: | 19 | # $CONDA is an environment variable pointing to the root of the miniconda directory 20 | echo $CONDA/bin >> $GITHUB_PATH 21 | - name: Install dependencies 22 | run: | 23 | conda create --name test-env python=3.12 24 | conda install pytest 25 | echo "Starting setup from $PWD" 26 | pip install -e . 27 | - name: Test with pytest 28 | run: | 29 | pytest tests 30 | -------------------------------------------------------------------------------- /.github/workflows/linters.yaml: -------------------------------------------------------------------------------- 1 | name: Lint (Common Code) 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | paths: 8 | - 'app/' 9 | - 'evals/*.py' 10 | - 'src/' 11 | - 'tests/' 12 | pull_request: 13 | branches: 14 | - master 15 | - 'gh/**' 16 | paths: 17 | - 'app/' 18 | - 'evals/*.py' 19 | - 'src/' 20 | - 'tests/' 21 | 22 | jobs: 23 | run-linters: 24 | name: Run linters 25 | runs-on: ubuntu-latest 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Set up Python 3.12 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: '3.12' 33 | - name: Install Python lint dependencies 34 | run: | 35 | pip install -r requirements-test.txt 36 | - name: Set lint paths 37 | run: echo "lint_paths=app evals/*.py src tests" >> "$GITHUB_ENV" 38 | - name: Run isort 39 | run: | 40 | python -m isort $lint_paths --check 41 | - name: Run flake8 42 | if: always() 43 | run: | 44 | python -m flake8 --config .flake8 --show-source --statistics $lint_paths 45 | - name: Run black 46 | if: always() 47 | run: | 48 | python -m black --check $lint_paths 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode/ 3 | .*.swp 4 | 5 | run_vjepa_aws.py 6 | run.py 7 | main_distributed_video.py 8 | main_video.py 9 | 10 | app/vjepa/configs/temp_aws 11 | app/main_dev.py 12 | app/main_distributed_dev.py 13 | evals/ava/alphaction/data 14 | 15 | run_evals.py 16 | run_evals_v2.py 17 | run_pretrain.py 18 | 19 | *.egg-info/ 20 | *.ipynb_checkpoints/ 21 | 22 | traces/ 23 | third_party/* 24 | 25 | evals/simu_env_planning/local/ 26 | evals/simu_env_planning/docker2/ 27 | evals/simu_env_planning/docker/ 28 | app/vjepa_droid/local/ 29 | app/vjepa_droid_v2/local/ 30 | app/vjepa_droid_v3/local/ 31 | app/vjepa_droid_v4/local/ 32 | configs/local -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [0.0.1] - 2025-06-05 4 | 5 | Initial release of V-JEPA 2 codebase -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, 
regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at <opensource-conduct@fb.com>. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to V-JEPA 2 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code is consistent with style guidance (below) and lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 7. Add reviewer(s) for approval. 15 | 16 | ## Contributor License Agreement ("CLA") 17 | In order to accept your pull request, we need you to submit a CLA. You only need 18 | to do this once to work on any of Facebook's open source projects. 19 | 20 | Complete your CLA here: <https://code.facebook.com/cla> 21 | 22 | ## Issues 23 | We use GitHub issues to track public bugs. Please ensure your description is 24 | clear and has sufficient instructions to be able to reproduce the issue. 25 | 26 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe 27 | disclosure of security bugs. In those cases, please go through the process 28 | outlined on that page and do not file a public issue. 29 | 30 | ## Coding Style 31 | * 4 spaces for indentation rather than tabs 32 | * 119 character line length 33 | * PEP8 formatting 34 | 35 | We recommend using `black`, `isort`, and `flake8` to format your code changes. 36 | 37 | ## License 38 | By contributing to this repository, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import multiprocessing as mp 8 | import pprint 9 | from pathlib import Path 10 | 11 | import yaml 12 | 13 | from app.scaffold import main as app_main 14 | from src.utils.distributed import init_distributed 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--fname", type=str, help="name of config file to load", default="configs.yaml") 18 | parser.add_argument( 19 | "--devices", 20 | type=str, 21 | nargs="+", 22 | default=["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"], 23 | help="which devices to use on local machine", 24 | ) 25 | parser.add_argument( 26 | "--debugmode", 27 | type=bool, 28 | default=False, 29 | help="Setting this to true will not spin up new processes. " 30 | "The main code runs the main process, which makes it easier to \ 31 | debug with checkpointing.", 32 | ) 33 | 34 | 35 | def process_main(rank, fname, world_size, devices): 36 | import os 37 | 38 | os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) 39 | 40 | import logging 41 | 42 | from src.utils.logging import get_logger 43 | 44 | logger = get_logger(force=True) 45 | if rank == 0: 46 | logger.setLevel(logging.INFO) 47 | else: 48 | logger.setLevel(logging.ERROR) 49 | 50 | logger.info(f"called-params {fname}") 51 | 52 | # Load config 53 | params = None 54 | with open(fname, "r") as y_file: 55 | params = yaml.load(y_file, Loader=yaml.FullLoader) 56 | logger.info("loaded params...") 57 | 58 | # Log config 59 | if rank == 0: 60 | pprint.PrettyPrinter(indent=4).pprint(params) 61 | folder = params["folder"] 62 | params_path = os.path.join(folder, "params-pretrain.yaml") 63 | folder = Path(folder) 64 | folder.mkdir(parents=True, exist_ok=True) 65 | with open(params_path, "w") as f: 66 | yaml.dump(params, f) 67 | 68 | # Init distributed (access to comm between GPUS on same machine) 69 | world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) 70 | logger.info(f"Running... (rank: {rank}/{world_size})") 71 | 72 | # Launch the app with loaded config 73 | app_main(params["app"], args=params) 74 | 75 | 76 | if __name__ == "__main__": 77 | args = parser.parse_args() 78 | if args.debugmode: 79 | process_main(rank=0, fname=args.fname, world_size=1, devices=["cuda:0"]) 80 | else: 81 | num_gpus = len(args.devices) 82 | mp.set_start_method("spawn") 83 | for rank in range(num_gpus): 84 | mp.Process(target=process_main, args=(rank, args.fname, num_gpus, args.devices)).start() 85 | -------------------------------------------------------------------------------- /app/scaffold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import logging 8 | import sys 9 | 10 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 11 | logger = logging.getLogger() 12 | 13 | 14 | def main(app, args, resume_preempt=False): 15 | 16 | logger.info(f"Running pre-training of app: {app}") 17 | return importlib.import_module(f"app.{app}.train").main(args=args, resume_preempt=resume_preempt) 18 | -------------------------------------------------------------------------------- /app/vjepa/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torchvision.transforms as transforms 8 | 9 | import src.datasets.utils.video.transforms as video_transforms 10 | from src.datasets.utils.video.randerase import RandomErasing 11 | 12 | 13 | def make_transforms( 14 | random_horizontal_flip=True, 15 | random_resize_aspect_ratio=(3 / 4, 4 / 3), 16 | random_resize_scale=(0.3, 1.0), 17 | reprob=0.0, 18 | auto_augment=False, 19 | motion_shift=False, 20 | crop_size=224, 21 | normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 22 | ): 23 | 24 | _frames_augmentation = VideoTransform( 25 | random_horizontal_flip=random_horizontal_flip, 26 | random_resize_aspect_ratio=random_resize_aspect_ratio, 27 | random_resize_scale=random_resize_scale, 28 | reprob=reprob, 29 | auto_augment=auto_augment, 30 | motion_shift=motion_shift, 31 | crop_size=crop_size, 32 | normalize=normalize, 33 | ) 34 | return _frames_augmentation 35 | 36 | 37 | class VideoTransform(object): 38 | 39 | def __init__( 40 | self, 41 | random_horizontal_flip=True, 42 | random_resize_aspect_ratio=(3 / 4, 4 / 3), 43 | random_resize_scale=(0.3, 1.0), 44 | reprob=0.0, 45 | auto_augment=False, 46 | motion_shift=False, 47 | crop_size=224, 48 | normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 49 | ): 50 | 51 | self.random_horizontal_flip = random_horizontal_flip 52 | self.random_resize_aspect_ratio = random_resize_aspect_ratio 53 | self.random_resize_scale = random_resize_scale 54 | self.auto_augment = auto_augment 55 | self.motion_shift = motion_shift 56 | self.crop_size = crop_size 57 | self.mean = torch.tensor(normalize[0], dtype=torch.float32) 58 | self.std = torch.tensor(normalize[1], dtype=torch.float32) 59 | if not self.auto_augment: 60 | # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. 
61 | self.mean *= 255.0 62 | self.std *= 255.0 63 | 64 | self.autoaug_transform = video_transforms.create_random_augment( 65 | input_size=(crop_size, crop_size), 66 | # auto_augment="rand-m4-n4-w1-mstd0.5-inc1", 67 | auto_augment="rand-m7-n4-mstd0.5-inc1", 68 | interpolation="bicubic", 69 | ) 70 | 71 | self.spatial_transform = ( 72 | video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop 73 | ) 74 | 75 | self.reprob = reprob 76 | self.erase_transform = RandomErasing( 77 | reprob, 78 | mode="pixel", 79 | max_count=1, 80 | num_splits=1, 81 | device="cpu", 82 | ) 83 | 84 | def __call__(self, buffer): 85 | 86 | if self.auto_augment: 87 | buffer = [transforms.ToPILImage()(frame) for frame in buffer] 88 | buffer = self.autoaug_transform(buffer) 89 | buffer = [transforms.ToTensor()(img) for img in buffer] 90 | buffer = torch.stack(buffer) # T C H W 91 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 92 | elif torch.is_tensor(buffer): 93 | # TODO: ensure input is always a tensor? 94 | buffer = buffer.to(torch.float32) 95 | else: 96 | buffer = torch.tensor(buffer, dtype=torch.float32) 97 | 98 | buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W 99 | 100 | buffer = self.spatial_transform( 101 | images=buffer, 102 | target_height=self.crop_size, 103 | target_width=self.crop_size, 104 | scale=self.random_resize_scale, 105 | ratio=self.random_resize_aspect_ratio, 106 | ) 107 | if self.random_horizontal_flip: 108 | buffer, _ = video_transforms.horizontal_flip(0.5, buffer) 109 | 110 | buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) 111 | if self.reprob > 0: 112 | buffer = buffer.permute(1, 0, 2, 3) 113 | buffer = self.erase_transform(buffer) 114 | buffer = buffer.permute(1, 0, 2, 3) 115 | 116 | return buffer 117 | 118 | 119 | def tensor_normalize(tensor, mean, std): 120 | """ 121 | Normalize a given tensor by subtracting the mean and dividing the std. 122 | Args: 123 | tensor (tensor): tensor to normalize. 124 | mean (tensor or list): mean value to subtract. 125 | std (tensor or list): std to divide. 126 | """ 127 | if tensor.dtype == torch.uint8: 128 | tensor = tensor.float() 129 | tensor = tensor / 255.0 130 | if isinstance(mean, list): 131 | mean = torch.tensor(mean) 132 | if isinstance(std, list): 133 | std = torch.tensor(std) 134 | tensor = tensor - mean 135 | tensor = tensor / std 136 | return tensor 137 | 138 | 139 | def _tensor_normalize_inplace(tensor, mean, std): 140 | """ 141 | Normalize a given tensor by subtracting the mean and dividing the std. 142 | Args: 143 | tensor (tensor): tensor to normalize (with dimensions C, T, H, W). 144 | mean (tensor): mean value to subtract (in 0 to 255 floats). 145 | std (tensor): std to divide (in 0 to 255 floats). 146 | """ 147 | if tensor.dtype == torch.uint8: 148 | tensor = tensor.float() 149 | 150 | C, T, H, W = tensor.shape 151 | tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension 152 | tensor.sub_(mean).div_(std) 153 | tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front 154 | return tensor 155 | -------------------------------------------------------------------------------- /app/vjepa_droid/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import torch 9 | import torchvision.transforms as transforms 10 | 11 | import src.datasets.utils.video.transforms as video_transforms 12 | from src.datasets.utils.video.randerase import RandomErasing 13 | 14 | 15 | def make_transforms( 16 | random_horizontal_flip=True, 17 | random_resize_aspect_ratio=(3 / 4, 4 / 3), 18 | random_resize_scale=(0.3, 1.0), 19 | reprob=0.0, 20 | auto_augment=False, 21 | motion_shift=False, 22 | crop_size=224, 23 | normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 24 | ): 25 | 26 | _frames_augmentation = VideoTransform( 27 | random_horizontal_flip=random_horizontal_flip, 28 | random_resize_aspect_ratio=random_resize_aspect_ratio, 29 | random_resize_scale=random_resize_scale, 30 | reprob=reprob, 31 | auto_augment=auto_augment, 32 | motion_shift=motion_shift, 33 | crop_size=crop_size, 34 | normalize=normalize, 35 | ) 36 | return _frames_augmentation 37 | 38 | 39 | class VideoTransform(object): 40 | 41 | def __init__( 42 | self, 43 | random_horizontal_flip=True, 44 | random_resize_aspect_ratio=(3 / 4, 4 / 3), 45 | random_resize_scale=(0.3, 1.0), 46 | reprob=0.0, 47 | auto_augment=False, 48 | motion_shift=False, 49 | crop_size=224, 50 | normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 51 | ): 52 | 53 | self.random_horizontal_flip = random_horizontal_flip 54 | self.random_resize_aspect_ratio = random_resize_aspect_ratio 55 | self.random_resize_scale = random_resize_scale 56 | self.auto_augment = auto_augment 57 | self.motion_shift = motion_shift 58 | self.crop_size = crop_size 59 | self.mean = torch.tensor(normalize[0], dtype=torch.float32) 60 | self.std = torch.tensor(normalize[1], dtype=torch.float32) 61 | if not self.auto_augment: 62 | # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. 63 | self.mean *= 255.0 64 | self.std *= 255.0 65 | 66 | self.autoaug_transform = video_transforms.create_random_augment( 67 | input_size=(crop_size, crop_size), 68 | # auto_augment="rand-m4-n4-w1-mstd0.5-inc1", 69 | auto_augment="rand-m7-n4-mstd0.5-inc1", 70 | interpolation="bicubic", 71 | ) 72 | 73 | self.spatial_transform = ( 74 | video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop 75 | ) 76 | 77 | self.reprob = reprob 78 | self.erase_transform = RandomErasing( 79 | reprob, 80 | mode="pixel", 81 | max_count=1, 82 | num_splits=1, 83 | device="cpu", 84 | ) 85 | 86 | def __call__(self, buffer): 87 | 88 | if self.auto_augment: 89 | buffer = [transforms.ToPILImage()(frame) for frame in buffer] 90 | buffer = self.autoaug_transform(buffer) 91 | buffer = [transforms.ToTensor()(img) for img in buffer] 92 | buffer = torch.stack(buffer) # T C H W 93 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 94 | elif torch.is_tensor(buffer): 95 | # TODO: ensure input is always a tensor? 
96 | buffer = buffer.to(torch.float32) 97 | else: 98 | buffer = torch.tensor(buffer, dtype=torch.float32) 99 | 100 | buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W 101 | 102 | buffer = self.spatial_transform( 103 | images=buffer, 104 | target_height=self.crop_size, 105 | target_width=self.crop_size, 106 | scale=self.random_resize_scale, 107 | ratio=self.random_resize_aspect_ratio, 108 | ) 109 | if self.random_horizontal_flip: 110 | buffer, _ = video_transforms.horizontal_flip(0.5, buffer) 111 | 112 | buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) 113 | if self.reprob > 0: 114 | buffer = buffer.permute(1, 0, 2, 3) 115 | buffer = self.erase_transform(buffer) 116 | buffer = buffer.permute(1, 0, 2, 3) 117 | 118 | return buffer 119 | 120 | 121 | def tensor_normalize(tensor, mean, std): 122 | """ 123 | Normalize a given tensor by subtracting the mean and dividing the std. 124 | Args: 125 | tensor (tensor): tensor to normalize. 126 | mean (tensor or list): mean value to subtract. 127 | std (tensor or list): std to divide. 128 | """ 129 | if tensor.dtype == torch.uint8: 130 | tensor = tensor.float() 131 | tensor = tensor / 255.0 132 | if type(mean) == list: 133 | mean = torch.tensor(mean) 134 | if type(std) == list: 135 | std = torch.tensor(std) 136 | tensor = tensor - mean 137 | tensor = tensor / std 138 | return tensor 139 | 140 | 141 | def _tensor_normalize_inplace(tensor, mean, std): 142 | """ 143 | Normalize a given tensor by subtracting the mean and dividing the std. 144 | Args: 145 | tensor (tensor): tensor to normalize (with dimensions C, T, H, W). 146 | mean (tensor): mean value to subtract (in 0 to 255 floats). 147 | std (tensor): std to divide (in 0 to 255 floats). 148 | """ 149 | if tensor.dtype == torch.uint8: 150 | tensor = tensor.float() 151 | 152 | C, T, H, W = tensor.shape 153 | tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension 154 | tensor.sub_(mean).div_(std) 155 | tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front 156 | return tensor 157 | -------------------------------------------------------------------------------- /assets/flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/assets/flowchart.png -------------------------------------------------------------------------------- /assets/vjepa2-abstract-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/assets/vjepa2-abstract-new.png -------------------------------------------------------------------------------- /assets/vjepa2-ac-abstract-new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/assets/vjepa2-ac-abstract-new.png -------------------------------------------------------------------------------- /configs/eval/vitg-384/coin.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitg-384/coin 4 | mem_per_gpu: 220G 5 | nodes: 16 6 | resume_checkpoint: true 7 | tag: coin-vitg16-384-16x8x3 8 | tasks_per_node: 8 9 | experiment: 10 | classifier: 11 | num_heads: 16 12 | num_probe_blocks: 4 13 | data: 14 | dataset_type: 
VideoDataset 15 | dataset_train: /your_data_folder/COIN/train_paths.csv 16 | dataset_val: /your_data_folder/COIN/val_paths.csv 17 | frame_step: 4 18 | frames_per_clip: 16 19 | num_classes: 180 20 | num_segments: 8 21 | num_views_per_segment: 3 22 | resolution: 384 23 | optimization: 24 | batch_size: 1 25 | multihead_kwargs: 26 | - final_lr: 0.0 27 | final_weight_decay: 0.01 28 | lr: 0.005 29 | start_lr: 0.005 30 | warmup: 0.0 31 | weight_decay: 0.01 32 | - final_lr: 0.0 33 | final_weight_decay: 0.01 34 | lr: 0.003 35 | start_lr: 0.003 36 | warmup: 0.0 37 | weight_decay: 0.01 38 | - final_lr: 0.0 39 | final_weight_decay: 0.01 40 | lr: 0.001 41 | start_lr: 0.001 42 | warmup: 0.0 43 | weight_decay: 0.01 44 | - final_lr: 0.0 45 | final_weight_decay: 0.01 46 | lr: 0.0003 47 | start_lr: 0.0003 48 | warmup: 0.0 49 | weight_decay: 0.01 50 | - final_lr: 0.0 51 | final_weight_decay: 0.01 52 | lr: 0.0001 53 | start_lr: 0.0001 54 | warmup: 0.0 55 | weight_decay: 0.01 56 | - final_lr: 0.0 57 | final_weight_decay: 0.1 58 | lr: 0.005 59 | start_lr: 0.005 60 | warmup: 0.0 61 | weight_decay: 0.1 62 | - final_lr: 0.0 63 | final_weight_decay: 0.1 64 | lr: 0.003 65 | start_lr: 0.003 66 | warmup: 0.0 67 | weight_decay: 0.1 68 | - final_lr: 0.0 69 | final_weight_decay: 0.1 70 | lr: 0.001 71 | start_lr: 0.001 72 | warmup: 0.0 73 | weight_decay: 0.1 74 | - final_lr: 0.0 75 | final_weight_decay: 0.1 76 | lr: 0.0003 77 | start_lr: 0.0003 78 | warmup: 0.0 79 | weight_decay: 0.1 80 | - final_lr: 0.0 81 | final_weight_decay: 0.1 82 | lr: 0.0001 83 | start_lr: 0.0001 84 | warmup: 0.0 85 | weight_decay: 0.1 86 | - final_lr: 0.0 87 | final_weight_decay: 0.4 88 | lr: 0.005 89 | start_lr: 0.005 90 | warmup: 0.0 91 | weight_decay: 0.4 92 | - final_lr: 0.0 93 | final_weight_decay: 0.4 94 | lr: 0.003 95 | start_lr: 0.003 96 | warmup: 0.0 97 | weight_decay: 0.4 98 | - final_lr: 0.0 99 | final_weight_decay: 0.4 100 | lr: 0.001 101 | start_lr: 0.001 102 | warmup: 0.0 103 | weight_decay: 0.4 104 | - final_lr: 0.0 105 | final_weight_decay: 0.4 106 | lr: 0.0003 107 | start_lr: 0.0003 108 | warmup: 0.0 109 | weight_decay: 0.4 110 | - final_lr: 0.0 111 | final_weight_decay: 0.4 112 | lr: 0.0001 113 | start_lr: 0.0001 114 | warmup: 0.0 115 | weight_decay: 0.4 116 | - final_lr: 0.0 117 | final_weight_decay: 0.8 118 | lr: 0.005 119 | start_lr: 0.005 120 | warmup: 0.0 121 | weight_decay: 0.8 122 | - final_lr: 0.0 123 | final_weight_decay: 0.8 124 | lr: 0.003 125 | start_lr: 0.003 126 | warmup: 0.0 127 | weight_decay: 0.8 128 | - final_lr: 0.0 129 | final_weight_decay: 0.8 130 | lr: 0.001 131 | start_lr: 0.001 132 | warmup: 0.0 133 | weight_decay: 0.8 134 | - final_lr: 0.0 135 | final_weight_decay: 0.8 136 | lr: 0.0003 137 | start_lr: 0.0003 138 | warmup: 0.0 139 | weight_decay: 0.8 140 | - final_lr: 0.0 141 | final_weight_decay: 0.8 142 | lr: 0.0001 143 | start_lr: 0.0001 144 | warmup: 0.0 145 | weight_decay: 0.8 146 | num_epochs: 20 147 | use_bfloat16: true 148 | use_pos_embed: false 149 | model_kwargs: 150 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 151 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 152 | pretrain_kwargs: 153 | encoder: 154 | checkpoint_key: target_encoder 155 | img_temporal_dim_size: null 156 | model_name: vit_giant_xformers 157 | patch_size: 16 158 | tubelet_size: 2 159 | uniform_power: true 160 | use_rope: true 161 | wrapper_kwargs: 162 | max_frames: 128 163 | use_pos_embed: false 164 | -------------------------------------------------------------------------------- 
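Note on the eval configs in this directory: each file follows the same layout — top-level launch fields (nodes, tasks_per_node, folder, tag), an experiment block with classifier, data, and optimization sections, and a model_kwargs block pointing at a frozen V-JEPA 2 checkpoint. Within optimization, multihead_kwargs lists one optimizer setting (lr, weight decay, schedule) per attentive probe, so a single run sweeps the whole grid. The sketch below is illustrative only and is not the repo's eval entrypoint (evals/main.py consumes these files); it assumes just PyYAML and the key names visible in the configs, and the helper name is made up for this example.

```python
# Illustrative sketch only -- not the repo's evals/main.py entrypoint.
# Assumes PyYAML and the key names visible in the eval configs above
# (eval_name, tag, experiment.optimization.{multihead_kwargs, num_epochs, batch_size}).
import yaml


def summarize_probe_grid(config_path: str) -> None:
    """Print the attentive-probe hyperparameter grid defined in one eval config."""
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)
    opt = cfg["experiment"]["optimization"]
    grid = opt["multihead_kwargs"]  # presumably one probe per entry
    print(
        f"{cfg['eval_name']} / {cfg['tag']}: {len(grid)} probes, "
        f"{opt['num_epochs']} epochs, batch_size={opt['batch_size']}"
    )
    for i, head in enumerate(grid):
        print(f"  probe {i:02d}: lr={head['lr']}, weight_decay={head['weight_decay']}")


if __name__ == "__main__":
    # Path is just the config shown above; substitute any file under configs/eval/.
    summarize_probe_grid("configs/eval/vitg-384/coin.yaml")
```

For the coin.yaml shown above, that grid is 5 learning rates (0.005 down to 0.0001) crossed with 4 weight-decay values (0.01, 0.1, 0.4, 0.8), i.e. 20 probes trained in one run.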
/configs/eval/vitg-384/diving48.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | mem_per_gpu: 200G 5 | tag: diving48-vitg16-384-32x4x3 6 | eval_name: video_classification_frozen 7 | folder: /your_folder/evals/vitg-384/diving48 8 | resume_checkpoint: true 9 | experiment: 10 | classifier: 11 | num_probe_blocks: 4 12 | num_heads: 16 13 | data: 14 | dataset_type: VideoDataset 15 | dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv 16 | dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv 17 | num_classes: 48 18 | resolution: 384 19 | frames_per_clip: 32 20 | frame_step: 2 21 | num_segments: 4 22 | num_views_per_segment: 3 23 | optimization: 24 | use_pos_embed: false 25 | num_epochs: 100 26 | batch_size: 2 27 | use_bfloat16: true 28 | multihead_kwargs: 29 | - weight_decay: 0.8 30 | final_weight_decay: 0.8 31 | lr: 0.001 32 | start_lr: 0.001 33 | final_lr: 0.0 34 | warmup: 0. 35 | - weight_decay: 0.8 36 | final_weight_decay: 0.8 37 | lr: 0.0003 38 | start_lr: 0.0003 39 | final_lr: 0.0 40 | warmup: 0. 41 | - weight_decay: 0.8 42 | final_weight_decay: 0.8 43 | lr: 0.0001 44 | start_lr: 0.0001 45 | final_lr: 0.0 46 | warmup: 0. 47 | model_kwargs: 48 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 49 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel 50 | wrapper_kwargs: 51 | max_frames: 128 52 | use_pos_embed: false 53 | out_layers: [24, 29, 34, 39] 54 | pretrain_kwargs: 55 | encoder: 56 | model_name: vit_giant_xformers 57 | checkpoint_key: target_encoder 58 | tubelet_size: 2 59 | patch_size: 16 60 | uniform_power: true 61 | use_rope: true 62 | -------------------------------------------------------------------------------- /configs/eval/vitg-384/ek100.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | tag: ek100-vitg16-384 5 | eval_name: action_anticipation_frozen 6 | folder: /your_folder/evals/vitg-384/ek100 7 | resume_checkpoint: true 8 | experiment: 9 | classifier: 10 | num_probe_blocks: 4 11 | num_heads: 16 12 | data: 13 | anticipation_time_sec: 14 | - 1.0 15 | - 1.0 16 | auto_augment: true 17 | file_format: 0 18 | dataset: EK100 19 | # e.g. /home/username/EPIC-KITCHENS 20 | base_path: /your_ek100_root_dir/ 21 | dataset_train: /your_data/EPIC_100_train.csv 22 | dataset_val: /your_data/EPIC_100_validation.csv 23 | frames_per_clip: 32 24 | frames_per_second: 8 25 | motion_shift: false 26 | num_workers: 2 27 | pin_memory: true 28 | random_resize_scale: 29 | - 0.08 30 | - 1.0 31 | reprob: 0.25 32 | resolution: 384 33 | train_anticipation_point: 34 | - 0.0 35 | - 0.25 36 | train_anticipation_time_sec: 37 | - 0.25 38 | - 1.75 39 | optimization: 40 | num_epochs: 20 41 | batch_size: 2 42 | use_bfloat16: true 43 | use_focal_loss: true 44 | multihead_kwargs: 45 | - weight_decay: 0.0001 46 | final_weight_decay: 0.0001 47 | lr: 0.005 48 | start_lr: 0.005 49 | final_lr: 0.0 50 | warmup: 0. 51 | - weight_decay: 0.0001 52 | final_weight_decay: 0.0001 53 | lr: 0.003 54 | start_lr: 0.003 55 | final_lr: 0.0 56 | warmup: 0. 57 | - weight_decay: 0.0001 58 | final_weight_decay: 0.0001 59 | lr: 0.001 60 | start_lr: 0.001 61 | final_lr: 0.0 62 | warmup: 0. 63 | - weight_decay: 0.0001 64 | final_weight_decay: 0.0001 65 | lr: 0.0003 66 | start_lr: 0.0003 67 | final_lr: 0.0 68 | warmup: 0. 
69 | - weight_decay: 0.0001 70 | final_weight_decay: 0.0001 71 | lr: 0.0001 72 | start_lr: 0.0001 73 | final_lr: 0.0 74 | warmup: 0. 75 | 76 | - weight_decay: 0.001 77 | final_weight_decay: 0.001 78 | lr: 0.005 79 | start_lr: 0.005 80 | final_lr: 0.0 81 | warmup: 0. 82 | - weight_decay: 0.001 83 | final_weight_decay: 0.001 84 | lr: 0.003 85 | start_lr: 0.003 86 | final_lr: 0.0 87 | warmup: 0. 88 | - weight_decay: 0.001 89 | final_weight_decay: 0.001 90 | lr: 0.001 91 | start_lr: 0.001 92 | final_lr: 0.0 93 | warmup: 0. 94 | - weight_decay: 0.001 95 | final_weight_decay: 0.001 96 | lr: 0.0003 97 | start_lr: 0.0003 98 | final_lr: 0.0 99 | warmup: 0. 100 | - weight_decay: 0.001 101 | final_weight_decay: 0.001 102 | lr: 0.0001 103 | start_lr: 0.0001 104 | final_lr: 0.0 105 | warmup: 0. 106 | 107 | - weight_decay: 0.01 108 | final_weight_decay: 0.01 109 | lr: 0.005 110 | start_lr: 0.005 111 | final_lr: 0.0 112 | warmup: 0. 113 | - weight_decay: 0.01 114 | final_weight_decay: 0.01 115 | lr: 0.003 116 | start_lr: 0.003 117 | final_lr: 0.0 118 | warmup: 0. 119 | - weight_decay: 0.01 120 | final_weight_decay: 0.01 121 | lr: 0.001 122 | start_lr: 0.001 123 | final_lr: 0.0 124 | warmup: 0. 125 | - weight_decay: 0.01 126 | final_weight_decay: 0.01 127 | lr: 0.0003 128 | start_lr: 0.0003 129 | final_lr: 0.0 130 | warmup: 0. 131 | - weight_decay: 0.01 132 | final_weight_decay: 0.01 133 | lr: 0.0001 134 | start_lr: 0.0001 135 | final_lr: 0.0 136 | warmup: 0. 137 | 138 | - weight_decay: 0.1 139 | final_weight_decay: 0.1 140 | lr: 0.005 141 | start_lr: 0.005 142 | final_lr: 0.0 143 | warmup: 0. 144 | - weight_decay: 0.1 145 | final_weight_decay: 0.1 146 | lr: 0.003 147 | start_lr: 0.003 148 | final_lr: 0.0 149 | warmup: 0. 150 | - weight_decay: 0.1 151 | final_weight_decay: 0.1 152 | lr: 0.001 153 | start_lr: 0.001 154 | final_lr: 0.0 155 | warmup: 0. 156 | - weight_decay: 0.1 157 | final_weight_decay: 0.1 158 | lr: 0.0003 159 | start_lr: 0.0003 160 | final_lr: 0.0 161 | warmup: 0. 162 | - weight_decay: 0.1 163 | final_weight_decay: 0.1 164 | lr: 0.0001 165 | start_lr: 0.0001 166 | final_lr: 0.0 167 | warmup: 0. 
168 | model_kwargs: 169 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 170 | module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar 171 | wrapper_kwargs: 172 | no_predictor: false 173 | num_output_frames: 2 174 | num_steps: 1 175 | pretrain_kwargs: 176 | encoder: 177 | model_name: vit_giant_xformers 178 | checkpoint_key: target_encoder 179 | tubelet_size: 2 180 | patch_size: 16 181 | uniform_power: true 182 | use_rope: true 183 | predictor: 184 | model_name: vit_predictor 185 | checkpoint_key: predictor 186 | num_frames: 64 187 | depth: 12 188 | num_heads: 12 189 | predictor_embed_dim: 384 190 | num_mask_tokens: 10 191 | uniform_power: true 192 | use_mask_tokens: true 193 | use_sdpa: true 194 | use_silu: false 195 | wide_silu: false 196 | use_rope: true 197 | -------------------------------------------------------------------------------- /configs/eval/vitg-384/in1k.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: image_classification_frozen 3 | folder: /your_folder/evals/vitg-384/in1k 4 | mem_per_gpu: 220G 5 | nodes: 16 6 | resume_checkpoint: true 7 | tag: in1k-vitg16-384-18f 8 | tasks_per_node: 8 9 | experiment: 10 | classifier: 11 | num_heads: 16 12 | num_probe_blocks: 4 13 | data: 14 | dataset_name: ImageNet 15 | num_classes: 1000 16 | root_path: /datasets/ 17 | image_folder: ImageNet_FullSize/240712/061417/ 18 | resolution: 384 19 | optimization: 20 | batch_size: 8 21 | multihead_kwargs: 22 | - final_lr: 0.0 23 | final_weight_decay: 0.0 24 | lr: 0.001 25 | start_lr: 0.001 26 | warmup: 0.0 27 | weight_decay: 0.001 28 | - final_lr: 0.0 29 | final_weight_decay: 0.008 30 | lr: 0.0005 31 | start_lr: 0.0002 32 | warmup: 5 33 | weight_decay: 0.008 34 | - final_lr: 0.0 35 | final_weight_decay: 0.004 36 | lr: 0.0005 37 | start_lr: 0.0002 38 | warmup: 5 39 | weight_decay: 0.004 40 | - final_lr: 0.0 41 | final_weight_decay: 0.002 42 | lr: 0.0005 43 | start_lr: 0.0002 44 | warmup: 5 45 | weight_decay: 0.002 46 | - final_lr: 0.0 47 | final_weight_decay: 0.001 48 | lr: 0.0005 49 | start_lr: 0.0002 50 | warmup: 5 51 | weight_decay: 0.001 52 | - final_lr: 0.0 53 | final_weight_decay: 0.0005 54 | lr: 0.0005 55 | start_lr: 0.0002 56 | warmup: 5 57 | weight_decay: 0.0005 58 | - final_lr: 0.0 59 | final_weight_decay: 0.008 60 | lr: 0.001 61 | start_lr: 0.0002 62 | warmup: 5 63 | weight_decay: 0.008 64 | - final_lr: 0.0 65 | final_weight_decay: 0.004 66 | lr: 0.001 67 | start_lr: 0.0002 68 | warmup: 5 69 | weight_decay: 0.004 70 | - final_lr: 0.0 71 | final_weight_decay: 0.002 72 | lr: 0.001 73 | start_lr: 0.0002 74 | warmup: 5 75 | weight_decay: 0.002 76 | - final_lr: 0.0 77 | final_weight_decay: 0.001 78 | lr: 0.001 79 | start_lr: 0.0002 80 | warmup: 5 81 | weight_decay: 0.001 82 | - final_lr: 0.0 83 | final_weight_decay: 0.0005 84 | lr: 0.001 85 | start_lr: 0.0002 86 | warmup: 5 87 | weight_decay: 0.0005 88 | - final_lr: 0.0 89 | final_weight_decay: 0.008 90 | lr: 0.0015 91 | start_lr: 0.0002 92 | warmup: 5 93 | weight_decay: 0.008 94 | - final_lr: 0.0 95 | final_weight_decay: 0.004 96 | lr: 0.0015 97 | start_lr: 0.0002 98 | warmup: 5 99 | weight_decay: 0.004 100 | - final_lr: 0.0 101 | final_weight_decay: 0.002 102 | lr: 0.0015 103 | start_lr: 0.0002 104 | warmup: 5 105 | weight_decay: 0.002 106 | - final_lr: 0.0 107 | final_weight_decay: 0.001 108 | lr: 0.0015 109 | start_lr: 0.0002 110 | warmup: 5 111 | weight_decay: 0.001 112 | - final_lr: 0.0 113 | final_weight_decay: 0.0005 114 | lr: 
0.0015 115 | start_lr: 0.0002 116 | warmup: 5 117 | weight_decay: 0.0005 118 | - final_lr: 0.0 119 | final_weight_decay: 0.008 120 | lr: 0.002 121 | start_lr: 0.0002 122 | warmup: 5 123 | weight_decay: 0.008 124 | - final_lr: 0.0 125 | final_weight_decay: 0.004 126 | lr: 0.002 127 | start_lr: 0.0002 128 | warmup: 5 129 | weight_decay: 0.004 130 | - final_lr: 0.0 131 | final_weight_decay: 0.002 132 | lr: 0.002 133 | start_lr: 0.0002 134 | warmup: 5 135 | weight_decay: 0.002 136 | - final_lr: 0.0 137 | final_weight_decay: 0.001 138 | lr: 0.002 139 | start_lr: 0.0002 140 | warmup: 5 141 | weight_decay: 0.001 142 | - final_lr: 0.0 143 | final_weight_decay: 0.0005 144 | lr: 0.002 145 | start_lr: 0.0002 146 | warmup: 5 147 | weight_decay: 0.0005 148 | num_epochs: 20 149 | use_bfloat16: true 150 | model_kwargs: 151 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 152 | module_name: evals.image_classification_frozen.modelcustom.vit_encoder 153 | pretrain_kwargs: 154 | encoder: 155 | checkpoint_key: target_encoder 156 | img_temporal_dim_size: null 157 | model_name: vit_giant_xformers 158 | patch_size: 16 159 | tubelet_size: 2 160 | uniform_power: true 161 | use_rope: true 162 | wrapper_kwargs: 163 | img_as_video_nframes: 18 164 | -------------------------------------------------------------------------------- /configs/eval/vitg-384/jester.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | mem_per_gpu: 200G 5 | tag: jester-vitg16-384-32x4x3 6 | eval_name: video_classification_frozen 7 | folder: /your_folder/evals/vitg-384/jester 8 | resume_checkpoint: true 9 | experiment: 10 | classifier: 11 | num_probe_blocks: 4 12 | num_heads: 16 13 | data: 14 | dataset_type: VideoDataset 15 | dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv 16 | dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv 17 | num_classes: 27 18 | resolution: 384 19 | frames_per_clip: 32 20 | frame_step: 2 21 | num_segments: 4 22 | num_views_per_segment: 3 23 | optimization: 24 | use_pos_embed: false 25 | num_epochs: 100 26 | batch_size: 2 27 | use_bfloat16: true 28 | multihead_kwargs: 29 | - weight_decay: 0.8 30 | final_weight_decay: 0.8 31 | lr: 0.001 32 | start_lr: 0.001 33 | final_lr: 0.0 34 | warmup: 0. 35 | - weight_decay: 0.8 36 | final_weight_decay: 0.8 37 | lr: 0.0003 38 | start_lr: 0.0003 39 | final_lr: 0.0 40 | warmup: 0. 41 | - weight_decay: 0.8 42 | final_weight_decay: 0.8 43 | lr: 0.0001 44 | start_lr: 0.0001 45 | final_lr: 0.0 46 | warmup: 0. 
47 | model_kwargs: 48 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 49 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel 50 | wrapper_kwargs: 51 | max_frames: 128 52 | use_pos_embed: false 53 | out_layers: [24, 29, 34, 39] 54 | pretrain_kwargs: 55 | encoder: 56 | model_name: vit_giant_xformers 57 | checkpoint_key: target_encoder 58 | tubelet_size: 2 59 | patch_size: 16 60 | uniform_power: true 61 | use_rope: true 62 | -------------------------------------------------------------------------------- /configs/eval/vitg-384/k400.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitg-384/k400 4 | mem_per_gpu: 220G 5 | nodes: 32 6 | num_workers: 8 7 | resume_checkpoint: true 8 | tag: k400-vitg16-384-16x8x3-16f 9 | tasks_per_node: 8 10 | experiment: 11 | classifier: 12 | num_heads: 16 13 | num_probe_blocks: 4 14 | data: 15 | dataset_type: VideoDataset 16 | dataset_train: /your_data_path/k400_train_paths.csv 17 | dataset_val: /your_data_path/k400_val_paths.csv 18 | frame_step: 4 19 | frames_per_clip: 16 20 | num_classes: 400 21 | num_segments: 8 22 | num_views_per_segment: 3 23 | resolution: 384 24 | optimization: 25 | batch_size: 1 26 | multihead_kwargs: 27 | - final_lr: 0.0 28 | final_weight_decay: 0.01 29 | lr: 0.005 30 | start_lr: 0.005 31 | warmup: 0.0 32 | weight_decay: 0.01 33 | - final_lr: 0.0 34 | final_weight_decay: 0.01 35 | lr: 0.003 36 | start_lr: 0.003 37 | warmup: 0.0 38 | weight_decay: 0.01 39 | - final_lr: 0.0 40 | final_weight_decay: 0.01 41 | lr: 0.001 42 | start_lr: 0.001 43 | warmup: 0.0 44 | weight_decay: 0.01 45 | - final_lr: 0.0 46 | final_weight_decay: 0.01 47 | lr: 0.0003 48 | start_lr: 0.0003 49 | warmup: 0.0 50 | weight_decay: 0.01 51 | - final_lr: 0.0 52 | final_weight_decay: 0.01 53 | lr: 0.0001 54 | start_lr: 0.0001 55 | warmup: 0.0 56 | weight_decay: 0.01 57 | - final_lr: 0.0 58 | final_weight_decay: 0.1 59 | lr: 0.005 60 | start_lr: 0.005 61 | warmup: 0.0 62 | weight_decay: 0.1 63 | - final_lr: 0.0 64 | final_weight_decay: 0.1 65 | lr: 0.003 66 | start_lr: 0.003 67 | warmup: 0.0 68 | weight_decay: 0.1 69 | - final_lr: 0.0 70 | final_weight_decay: 0.1 71 | lr: 0.001 72 | start_lr: 0.001 73 | warmup: 0.0 74 | weight_decay: 0.1 75 | - final_lr: 0.0 76 | final_weight_decay: 0.1 77 | lr: 0.0003 78 | start_lr: 0.0003 79 | warmup: 0.0 80 | weight_decay: 0.1 81 | - final_lr: 0.0 82 | final_weight_decay: 0.1 83 | lr: 0.0001 84 | start_lr: 0.0001 85 | warmup: 0.0 86 | weight_decay: 0.1 87 | - final_lr: 0.0 88 | final_weight_decay: 0.4 89 | lr: 0.005 90 | start_lr: 0.005 91 | warmup: 0.0 92 | weight_decay: 0.4 93 | - final_lr: 0.0 94 | final_weight_decay: 0.4 95 | lr: 0.003 96 | start_lr: 0.003 97 | warmup: 0.0 98 | weight_decay: 0.4 99 | - final_lr: 0.0 100 | final_weight_decay: 0.4 101 | lr: 0.001 102 | start_lr: 0.001 103 | warmup: 0.0 104 | weight_decay: 0.4 105 | - final_lr: 0.0 106 | final_weight_decay: 0.4 107 | lr: 0.0003 108 | start_lr: 0.0003 109 | warmup: 0.0 110 | weight_decay: 0.4 111 | - final_lr: 0.0 112 | final_weight_decay: 0.4 113 | lr: 0.0001 114 | start_lr: 0.0001 115 | warmup: 0.0 116 | weight_decay: 0.4 117 | - final_lr: 0.0 118 | final_weight_decay: 0.8 119 | lr: 0.005 120 | start_lr: 0.005 121 | warmup: 0.0 122 | weight_decay: 0.8 123 | - final_lr: 0.0 124 | final_weight_decay: 0.8 125 | lr: 0.003 126 | start_lr: 0.003 127 | warmup: 0.0 128 | weight_decay: 0.8 129 | - 
final_lr: 0.0 130 | final_weight_decay: 0.8 131 | lr: 0.001 132 | start_lr: 0.001 133 | warmup: 0.0 134 | weight_decay: 0.8 135 | - final_lr: 0.0 136 | final_weight_decay: 0.8 137 | lr: 0.0003 138 | start_lr: 0.0003 139 | warmup: 0.0 140 | weight_decay: 0.8 141 | - final_lr: 0.0 142 | final_weight_decay: 0.8 143 | lr: 0.0001 144 | start_lr: 0.0001 145 | warmup: 0.0 146 | weight_decay: 0.8 147 | num_epochs: 20 148 | use_bfloat16: true 149 | use_pos_embed: false 150 | model_kwargs: 151 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 152 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 153 | pretrain_kwargs: 154 | encoder: 155 | checkpoint_key: target_encoder 156 | img_temporal_dim_size: null 157 | model_name: vit_giant_xformers 158 | patch_size: 16 159 | tubelet_size: 2 160 | uniform_power: true 161 | use_rope: true 162 | wrapper_kwargs: 163 | max_frames: 128 164 | use_pos_embed: false 165 | -------------------------------------------------------------------------------- /configs/eval/vitg-384/ssv2.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitg-384/ssv2 4 | mem_per_gpu: 220G 5 | nodes: 16 6 | num_workers: 8 7 | resume_checkpoint: true 8 | tag: ssv2-vitg16-384-64x2x3 9 | tasks_per_node: 8 10 | experiment: 11 | classifier: 12 | num_heads: 16 13 | num_probe_blocks: 4 14 | data: 15 | dataset_type: VideoDataset 16 | dataset_train: /your_data_path/ssv2_train_paths.csv 17 | dataset_val: /your_data_path/ssv2_val_paths.csv 18 | frame_step: 2 19 | frames_per_clip: 64 20 | num_classes: 174 21 | num_segments: 2 22 | num_views_per_segment: 3 23 | resolution: 384 24 | optimization: 25 | batch_size: 2 26 | multihead_kwargs: 27 | - final_lr: 0.0 28 | final_weight_decay: 0.01 29 | lr: 0.005 30 | start_lr: 0.005 31 | warmup: 0.0 32 | weight_decay: 0.01 33 | - final_lr: 0.0 34 | final_weight_decay: 0.01 35 | lr: 0.003 36 | start_lr: 0.003 37 | warmup: 0.0 38 | weight_decay: 0.01 39 | - final_lr: 0.0 40 | final_weight_decay: 0.01 41 | lr: 0.001 42 | start_lr: 0.001 43 | warmup: 0.0 44 | weight_decay: 0.01 45 | - final_lr: 0.0 46 | final_weight_decay: 0.01 47 | lr: 0.0003 48 | start_lr: 0.0003 49 | warmup: 0.0 50 | weight_decay: 0.01 51 | - final_lr: 0.0 52 | final_weight_decay: 0.01 53 | lr: 0.0001 54 | start_lr: 0.0001 55 | warmup: 0.0 56 | weight_decay: 0.01 57 | - final_lr: 0.0 58 | final_weight_decay: 0.1 59 | lr: 0.005 60 | start_lr: 0.005 61 | warmup: 0.0 62 | weight_decay: 0.1 63 | - final_lr: 0.0 64 | final_weight_decay: 0.1 65 | lr: 0.003 66 | start_lr: 0.003 67 | warmup: 0.0 68 | weight_decay: 0.1 69 | - final_lr: 0.0 70 | final_weight_decay: 0.1 71 | lr: 0.001 72 | start_lr: 0.001 73 | warmup: 0.0 74 | weight_decay: 0.1 75 | - final_lr: 0.0 76 | final_weight_decay: 0.1 77 | lr: 0.0003 78 | start_lr: 0.0003 79 | warmup: 0.0 80 | weight_decay: 0.1 81 | - final_lr: 0.0 82 | final_weight_decay: 0.1 83 | lr: 0.0001 84 | start_lr: 0.0001 85 | warmup: 0.0 86 | weight_decay: 0.1 87 | - final_lr: 0.0 88 | final_weight_decay: 0.4 89 | lr: 0.005 90 | start_lr: 0.005 91 | warmup: 0.0 92 | weight_decay: 0.4 93 | - final_lr: 0.0 94 | final_weight_decay: 0.4 95 | lr: 0.003 96 | start_lr: 0.003 97 | warmup: 0.0 98 | weight_decay: 0.4 99 | - final_lr: 0.0 100 | final_weight_decay: 0.4 101 | lr: 0.001 102 | start_lr: 0.001 103 | warmup: 0.0 104 | weight_decay: 0.4 105 | - final_lr: 0.0 106 | final_weight_decay: 0.4 107 | lr: 0.0003 108 | 
start_lr: 0.0003 109 | warmup: 0.0 110 | weight_decay: 0.4 111 | - final_lr: 0.0 112 | final_weight_decay: 0.4 113 | lr: 0.0001 114 | start_lr: 0.0001 115 | warmup: 0.0 116 | weight_decay: 0.4 117 | - final_lr: 0.0 118 | final_weight_decay: 0.8 119 | lr: 0.005 120 | start_lr: 0.005 121 | warmup: 0.0 122 | weight_decay: 0.8 123 | - final_lr: 0.0 124 | final_weight_decay: 0.8 125 | lr: 0.003 126 | start_lr: 0.003 127 | warmup: 0.0 128 | weight_decay: 0.8 129 | - final_lr: 0.0 130 | final_weight_decay: 0.8 131 | lr: 0.001 132 | start_lr: 0.001 133 | warmup: 0.0 134 | weight_decay: 0.8 135 | - final_lr: 0.0 136 | final_weight_decay: 0.8 137 | lr: 0.0003 138 | start_lr: 0.0003 139 | warmup: 0.0 140 | weight_decay: 0.8 141 | - final_lr: 0.0 142 | final_weight_decay: 0.8 143 | lr: 0.0001 144 | start_lr: 0.0001 145 | warmup: 0.0 146 | weight_decay: 0.8 147 | num_epochs: 20 148 | use_bfloat16: true 149 | use_pos_embed: false 150 | model_kwargs: 151 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 152 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 153 | pretrain_kwargs: 154 | encoder: 155 | checkpoint_key: target_encoder 156 | img_temporal_dim_size: null 157 | model_name: vit_giant_xformers 158 | patch_size: 16 159 | tubelet_size: 2 160 | uniform_power: true 161 | use_rope: true 162 | wrapper_kwargs: 163 | max_frames: 128 164 | use_pos_embed: false 165 | -------------------------------------------------------------------------------- /configs/eval/vitl/coin.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitl/coin 4 | mem_per_gpu: 220G 5 | nodes: 8 6 | resume_checkpoint: true 7 | tag: coin-vitl16-384-16x8x3 8 | tasks_per_node: 8 9 | experiment: 10 | classifier: 11 | num_heads: 16 12 | num_probe_blocks: 4 13 | data: 14 | dataset_type: VideoDataset 15 | dataset_train: /your_data_folder/COIN/train_paths.csv 16 | dataset_val: /your_data_folder/COIN/val_paths.csv 17 | frame_step: 4 18 | frames_per_clip: 16 19 | num_classes: 180 20 | num_segments: 8 21 | num_views_per_segment: 3 22 | resolution: 256 23 | optimization: 24 | batch_size: 2 25 | multihead_kwargs: 26 | - final_lr: 0.0 27 | final_weight_decay: 0.01 28 | lr: 0.005 29 | start_lr: 0.005 30 | warmup: 0.0 31 | weight_decay: 0.01 32 | - final_lr: 0.0 33 | final_weight_decay: 0.01 34 | lr: 0.003 35 | start_lr: 0.003 36 | warmup: 0.0 37 | weight_decay: 0.01 38 | - final_lr: 0.0 39 | final_weight_decay: 0.01 40 | lr: 0.001 41 | start_lr: 0.001 42 | warmup: 0.0 43 | weight_decay: 0.01 44 | - final_lr: 0.0 45 | final_weight_decay: 0.01 46 | lr: 0.0003 47 | start_lr: 0.0003 48 | warmup: 0.0 49 | weight_decay: 0.01 50 | - final_lr: 0.0 51 | final_weight_decay: 0.01 52 | lr: 0.0001 53 | start_lr: 0.0001 54 | warmup: 0.0 55 | weight_decay: 0.01 56 | - final_lr: 0.0 57 | final_weight_decay: 0.1 58 | lr: 0.005 59 | start_lr: 0.005 60 | warmup: 0.0 61 | weight_decay: 0.1 62 | - final_lr: 0.0 63 | final_weight_decay: 0.1 64 | lr: 0.003 65 | start_lr: 0.003 66 | warmup: 0.0 67 | weight_decay: 0.1 68 | - final_lr: 0.0 69 | final_weight_decay: 0.1 70 | lr: 0.001 71 | start_lr: 0.001 72 | warmup: 0.0 73 | weight_decay: 0.1 74 | - final_lr: 0.0 75 | final_weight_decay: 0.1 76 | lr: 0.0003 77 | start_lr: 0.0003 78 | warmup: 0.0 79 | weight_decay: 0.1 80 | - final_lr: 0.0 81 | final_weight_decay: 0.1 82 | lr: 0.0001 83 | start_lr: 0.0001 84 | warmup: 0.0 85 | weight_decay: 0.1 86 | - final_lr: 0.0 87 
| final_weight_decay: 0.4 88 | lr: 0.005 89 | start_lr: 0.005 90 | warmup: 0.0 91 | weight_decay: 0.4 92 | - final_lr: 0.0 93 | final_weight_decay: 0.4 94 | lr: 0.003 95 | start_lr: 0.003 96 | warmup: 0.0 97 | weight_decay: 0.4 98 | - final_lr: 0.0 99 | final_weight_decay: 0.4 100 | lr: 0.001 101 | start_lr: 0.001 102 | warmup: 0.0 103 | weight_decay: 0.4 104 | - final_lr: 0.0 105 | final_weight_decay: 0.4 106 | lr: 0.0003 107 | start_lr: 0.0003 108 | warmup: 0.0 109 | weight_decay: 0.4 110 | - final_lr: 0.0 111 | final_weight_decay: 0.4 112 | lr: 0.0001 113 | start_lr: 0.0001 114 | warmup: 0.0 115 | weight_decay: 0.4 116 | - final_lr: 0.0 117 | final_weight_decay: 0.8 118 | lr: 0.005 119 | start_lr: 0.005 120 | warmup: 0.0 121 | weight_decay: 0.8 122 | - final_lr: 0.0 123 | final_weight_decay: 0.8 124 | lr: 0.003 125 | start_lr: 0.003 126 | warmup: 0.0 127 | weight_decay: 0.8 128 | - final_lr: 0.0 129 | final_weight_decay: 0.8 130 | lr: 0.001 131 | start_lr: 0.001 132 | warmup: 0.0 133 | weight_decay: 0.8 134 | - final_lr: 0.0 135 | final_weight_decay: 0.8 136 | lr: 0.0003 137 | start_lr: 0.0003 138 | warmup: 0.0 139 | weight_decay: 0.8 140 | - final_lr: 0.0 141 | final_weight_decay: 0.8 142 | lr: 0.0001 143 | start_lr: 0.0001 144 | warmup: 0.0 145 | weight_decay: 0.8 146 | num_epochs: 20 147 | use_bfloat16: true 148 | use_pos_embed: false 149 | model_kwargs: 150 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 151 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 152 | pretrain_kwargs: 153 | encoder: 154 | checkpoint_key: target_encoder 155 | img_temporal_dim_size: null 156 | model_name: vit_large 157 | patch_size: 16 158 | tubelet_size: 2 159 | uniform_power: true 160 | use_rope: true 161 | wrapper_kwargs: 162 | max_frames: 128 163 | use_pos_embed: false 164 | -------------------------------------------------------------------------------- /configs/eval/vitl/diving48.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | mem_per_gpu: 200G 5 | tag: diving48-vitl16-32x4x3 6 | eval_name: video_classification_frozen 7 | folder: /your_folder/evals/vitl/diving48 8 | resume_checkpoint: true 9 | experiment: 10 | classifier: 11 | num_probe_blocks: 4 12 | num_heads: 16 13 | data: 14 | dataset_type: VideoDataset 15 | dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv 16 | dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv 17 | num_classes: 48 18 | resolution: 256 19 | frames_per_clip: 32 20 | frame_step: 2 21 | num_segments: 4 22 | num_views_per_segment: 3 23 | optimization: 24 | use_pos_embed: false 25 | num_epochs: 100 26 | batch_size: 2 27 | use_bfloat16: true 28 | multihead_kwargs: 29 | - weight_decay: 0.8 30 | final_weight_decay: 0.8 31 | lr: 0.001 32 | start_lr: 0.001 33 | final_lr: 0.0 34 | warmup: 0. 35 | - weight_decay: 0.8 36 | final_weight_decay: 0.8 37 | lr: 0.0003 38 | start_lr: 0.0003 39 | final_lr: 0.0 40 | warmup: 0. 41 | - weight_decay: 0.8 42 | final_weight_decay: 0.8 43 | lr: 0.0001 44 | start_lr: 0.0001 45 | final_lr: 0.0 46 | warmup: 0. 
47 | model_kwargs: 48 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 49 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel 50 | wrapper_kwargs: 51 | max_frames: 128 52 | use_pos_embed: false 53 | out_layers: [17, 19, 21, 23] 54 | pretrain_kwargs: 55 | encoder: 56 | model_name: vit_large 57 | checkpoint_key: target_encoder 58 | tubelet_size: 2 59 | patch_size: 16 60 | uniform_power: true 61 | use_rope: true 62 | -------------------------------------------------------------------------------- /configs/eval/vitl/ek100.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | tag: ek100-vitl16 5 | eval_name: action_anticipation_frozen 6 | folder: /your_folder/evals/vitl/ek100 7 | resume_checkpoint: true 8 | experiment: 9 | classifier: 10 | num_probe_blocks: 4 11 | num_heads: 16 12 | data: 13 | anticipation_time_sec: 14 | - 1.0 15 | - 1.0 16 | auto_augment: true 17 | file_format: 0 18 | dataset: EK100 19 | # e.g. /home/username/EPIC-KITCHENS 20 | base_path: /your_ek100_root_dir/ 21 | dataset_train: /your_data/EPIC_100_train.csv 22 | dataset_val: /your_data/EPIC_100_validation.csv 23 | frames_per_clip: 32 24 | frames_per_second: 8 25 | motion_shift: false 26 | num_workers: 2 27 | pin_memory: true 28 | random_resize_scale: 29 | - 0.08 30 | - 1.0 31 | reprob: 0.25 32 | resolution: 256 33 | train_anticipation_point: 34 | - 0.0 35 | - 0.25 36 | train_anticipation_time_sec: 37 | - 0.25 38 | - 1.75 39 | optimization: 40 | num_epochs: 20 41 | batch_size: 2 42 | use_bfloat16: true 43 | use_focal_loss: true 44 | multihead_kwargs: 45 | - weight_decay: 0.0001 46 | final_weight_decay: 0.0001 47 | lr: 0.005 48 | start_lr: 0.005 49 | final_lr: 0.0 50 | warmup: 0. 51 | - weight_decay: 0.0001 52 | final_weight_decay: 0.0001 53 | lr: 0.003 54 | start_lr: 0.003 55 | final_lr: 0.0 56 | warmup: 0. 57 | - weight_decay: 0.0001 58 | final_weight_decay: 0.0001 59 | lr: 0.001 60 | start_lr: 0.001 61 | final_lr: 0.0 62 | warmup: 0. 63 | - weight_decay: 0.0001 64 | final_weight_decay: 0.0001 65 | lr: 0.0003 66 | start_lr: 0.0003 67 | final_lr: 0.0 68 | warmup: 0. 69 | - weight_decay: 0.0001 70 | final_weight_decay: 0.0001 71 | lr: 0.0001 72 | start_lr: 0.0001 73 | final_lr: 0.0 74 | warmup: 0. 75 | 76 | - weight_decay: 0.001 77 | final_weight_decay: 0.001 78 | lr: 0.005 79 | start_lr: 0.005 80 | final_lr: 0.0 81 | warmup: 0. 82 | - weight_decay: 0.001 83 | final_weight_decay: 0.001 84 | lr: 0.003 85 | start_lr: 0.003 86 | final_lr: 0.0 87 | warmup: 0. 88 | - weight_decay: 0.001 89 | final_weight_decay: 0.001 90 | lr: 0.001 91 | start_lr: 0.001 92 | final_lr: 0.0 93 | warmup: 0. 94 | - weight_decay: 0.001 95 | final_weight_decay: 0.001 96 | lr: 0.0003 97 | start_lr: 0.0003 98 | final_lr: 0.0 99 | warmup: 0. 100 | - weight_decay: 0.001 101 | final_weight_decay: 0.001 102 | lr: 0.0001 103 | start_lr: 0.0001 104 | final_lr: 0.0 105 | warmup: 0. 106 | 107 | - weight_decay: 0.01 108 | final_weight_decay: 0.01 109 | lr: 0.005 110 | start_lr: 0.005 111 | final_lr: 0.0 112 | warmup: 0. 113 | - weight_decay: 0.01 114 | final_weight_decay: 0.01 115 | lr: 0.003 116 | start_lr: 0.003 117 | final_lr: 0.0 118 | warmup: 0. 119 | - weight_decay: 0.01 120 | final_weight_decay: 0.01 121 | lr: 0.001 122 | start_lr: 0.001 123 | final_lr: 0.0 124 | warmup: 0. 125 | - weight_decay: 0.01 126 | final_weight_decay: 0.01 127 | lr: 0.0003 128 | start_lr: 0.0003 129 | final_lr: 0.0 130 | warmup: 0. 
131 | - weight_decay: 0.01 132 | final_weight_decay: 0.01 133 | lr: 0.0001 134 | start_lr: 0.0001 135 | final_lr: 0.0 136 | warmup: 0. 137 | 138 | - weight_decay: 0.1 139 | final_weight_decay: 0.1 140 | lr: 0.005 141 | start_lr: 0.005 142 | final_lr: 0.0 143 | warmup: 0. 144 | - weight_decay: 0.1 145 | final_weight_decay: 0.1 146 | lr: 0.003 147 | start_lr: 0.003 148 | final_lr: 0.0 149 | warmup: 0. 150 | - weight_decay: 0.1 151 | final_weight_decay: 0.1 152 | lr: 0.001 153 | start_lr: 0.001 154 | final_lr: 0.0 155 | warmup: 0. 156 | - weight_decay: 0.1 157 | final_weight_decay: 0.1 158 | lr: 0.0003 159 | start_lr: 0.0003 160 | final_lr: 0.0 161 | warmup: 0. 162 | - weight_decay: 0.1 163 | final_weight_decay: 0.1 164 | lr: 0.0001 165 | start_lr: 0.0001 166 | final_lr: 0.0 167 | warmup: 0. 168 | model_kwargs: 169 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 170 | module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar 171 | wrapper_kwargs: 172 | no_predictor: false 173 | num_output_frames: 2 174 | num_steps: 1 175 | pretrain_kwargs: 176 | encoder: 177 | model_name: vit_large 178 | checkpoint_key: target_encoder 179 | tubelet_size: 2 180 | patch_size: 16 181 | uniform_power: true 182 | use_rope: true 183 | predictor: 184 | model_name: vit_predictor 185 | checkpoint_key: predictor 186 | num_frames: 64 187 | depth: 12 188 | num_heads: 12 189 | predictor_embed_dim: 384 190 | num_mask_tokens: 10 191 | uniform_power: true 192 | use_mask_tokens: true 193 | use_sdpa: true 194 | use_silu: false 195 | wide_silu: false 196 | use_rope: true 197 | -------------------------------------------------------------------------------- /configs/eval/vitl/in1k.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: image_classification_frozen 3 | folder: /your_folder/evals/vitl/in1k 4 | mem_per_gpu: 220G 5 | nodes: 8 6 | resume_checkpoint: true 7 | tag: in1k-vitl16-384-18f 8 | tasks_per_node: 8 9 | experiment: 10 | classifier: 11 | num_heads: 16 12 | num_probe_blocks: 4 13 | data: 14 | dataset_name: ImageNet 15 | num_classes: 1000 16 | root_path: /datasets/ 17 | image_folder: ImageNet_FullSize/240712/061417/ 18 | resolution: 256 19 | optimization: 20 | batch_size: 16 21 | multihead_kwargs: 22 | - final_lr: 0.0 23 | final_weight_decay: 0.0 24 | lr: 0.001 25 | start_lr: 0.001 26 | warmup: 0.0 27 | weight_decay: 0.001 28 | - final_lr: 0.0 29 | final_weight_decay: 0.008 30 | lr: 0.0005 31 | start_lr: 0.0002 32 | warmup: 5 33 | weight_decay: 0.008 34 | - final_lr: 0.0 35 | final_weight_decay: 0.004 36 | lr: 0.0005 37 | start_lr: 0.0002 38 | warmup: 5 39 | weight_decay: 0.004 40 | - final_lr: 0.0 41 | final_weight_decay: 0.002 42 | lr: 0.0005 43 | start_lr: 0.0002 44 | warmup: 5 45 | weight_decay: 0.002 46 | - final_lr: 0.0 47 | final_weight_decay: 0.001 48 | lr: 0.0005 49 | start_lr: 0.0002 50 | warmup: 5 51 | weight_decay: 0.001 52 | - final_lr: 0.0 53 | final_weight_decay: 0.0005 54 | lr: 0.0005 55 | start_lr: 0.0002 56 | warmup: 5 57 | weight_decay: 0.0005 58 | - final_lr: 0.0 59 | final_weight_decay: 0.008 60 | lr: 0.001 61 | start_lr: 0.0002 62 | warmup: 5 63 | weight_decay: 0.008 64 | - final_lr: 0.0 65 | final_weight_decay: 0.004 66 | lr: 0.001 67 | start_lr: 0.0002 68 | warmup: 5 69 | weight_decay: 0.004 70 | - final_lr: 0.0 71 | final_weight_decay: 0.002 72 | lr: 0.001 73 | start_lr: 0.0002 74 | warmup: 5 75 | weight_decay: 0.002 76 | - final_lr: 0.0 77 | final_weight_decay: 0.001 78 | lr: 0.001 79 | 
start_lr: 0.0002 80 | warmup: 5 81 | weight_decay: 0.001 82 | - final_lr: 0.0 83 | final_weight_decay: 0.0005 84 | lr: 0.001 85 | start_lr: 0.0002 86 | warmup: 5 87 | weight_decay: 0.0005 88 | - final_lr: 0.0 89 | final_weight_decay: 0.008 90 | lr: 0.0015 91 | start_lr: 0.0002 92 | warmup: 5 93 | weight_decay: 0.008 94 | - final_lr: 0.0 95 | final_weight_decay: 0.004 96 | lr: 0.0015 97 | start_lr: 0.0002 98 | warmup: 5 99 | weight_decay: 0.004 100 | - final_lr: 0.0 101 | final_weight_decay: 0.002 102 | lr: 0.0015 103 | start_lr: 0.0002 104 | warmup: 5 105 | weight_decay: 0.002 106 | - final_lr: 0.0 107 | final_weight_decay: 0.001 108 | lr: 0.0015 109 | start_lr: 0.0002 110 | warmup: 5 111 | weight_decay: 0.001 112 | - final_lr: 0.0 113 | final_weight_decay: 0.0005 114 | lr: 0.0015 115 | start_lr: 0.0002 116 | warmup: 5 117 | weight_decay: 0.0005 118 | - final_lr: 0.0 119 | final_weight_decay: 0.008 120 | lr: 0.002 121 | start_lr: 0.0002 122 | warmup: 5 123 | weight_decay: 0.008 124 | - final_lr: 0.0 125 | final_weight_decay: 0.004 126 | lr: 0.002 127 | start_lr: 0.0002 128 | warmup: 5 129 | weight_decay: 0.004 130 | - final_lr: 0.0 131 | final_weight_decay: 0.002 132 | lr: 0.002 133 | start_lr: 0.0002 134 | warmup: 5 135 | weight_decay: 0.002 136 | - final_lr: 0.0 137 | final_weight_decay: 0.001 138 | lr: 0.002 139 | start_lr: 0.0002 140 | warmup: 5 141 | weight_decay: 0.001 142 | - final_lr: 0.0 143 | final_weight_decay: 0.0005 144 | lr: 0.002 145 | start_lr: 0.0002 146 | warmup: 5 147 | weight_decay: 0.0005 148 | num_epochs: 20 149 | use_bfloat16: true 150 | model_kwargs: 151 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 152 | module_name: evals.image_classification_frozen.modelcustom.vit_encoder 153 | pretrain_kwargs: 154 | encoder: 155 | checkpoint_key: target_encoder 156 | img_temporal_dim_size: null 157 | model_name: vit_large 158 | patch_size: 16 159 | tubelet_size: 2 160 | uniform_power: true 161 | use_rope: true 162 | wrapper_kwargs: 163 | img_as_video_nframes: 16 164 | -------------------------------------------------------------------------------- /configs/eval/vitl/jester.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | mem_per_gpu: 200G 5 | tag: jester-vitl16-32x4x3 6 | eval_name: video_classification_frozen 7 | folder: /your_folder/evals/vitl/jester 8 | resume_checkpoint: true 9 | experiment: 10 | classifier: 11 | num_probe_blocks: 4 12 | num_heads: 16 13 | data: 14 | dataset_type: VideoDataset 15 | dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv 16 | dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv 17 | num_classes: 27 18 | resolution: 256 19 | frames_per_clip: 32 20 | frame_step: 2 21 | num_segments: 4 22 | num_views_per_segment: 3 23 | optimization: 24 | use_pos_embed: false 25 | num_epochs: 100 26 | batch_size: 2 27 | use_bfloat16: true 28 | multihead_kwargs: 29 | - weight_decay: 0.8 30 | final_weight_decay: 0.8 31 | lr: 0.001 32 | start_lr: 0.001 33 | final_lr: 0.0 34 | warmup: 0. 35 | - weight_decay: 0.8 36 | final_weight_decay: 0.8 37 | lr: 0.0003 38 | start_lr: 0.0003 39 | final_lr: 0.0 40 | warmup: 0. 41 | - weight_decay: 0.8 42 | final_weight_decay: 0.8 43 | lr: 0.0001 44 | start_lr: 0.0001 45 | final_lr: 0.0 46 | warmup: 0. 
47 | model_kwargs: 48 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 49 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel 50 | wrapper_kwargs: 51 | max_frames: 128 52 | use_pos_embed: false 53 | out_layers: [17, 19, 21, 23] 54 | pretrain_kwargs: 55 | encoder: 56 | model_name: vit_large 57 | checkpoint_key: target_encoder 58 | tubelet_size: 2 59 | patch_size: 16 60 | uniform_power: true 61 | use_rope: true 62 | -------------------------------------------------------------------------------- /configs/eval/vitl/k400.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitl/k400 4 | mem_per_gpu: 220G 5 | nodes: 8 6 | num_workers: 8 7 | resume_checkpoint: true 8 | tag: k400-vitl16-16x8x3-16f 9 | tasks_per_node: 8 10 | experiment: 11 | classifier: 12 | num_heads: 16 13 | num_probe_blocks: 4 14 | data: 15 | dataset_type: VideoDataset 16 | dataset_train: /your_data_path/k400_train_paths.csv 17 | dataset_val: /your_data_path/k400_val_paths.csv 18 | frame_step: 4 19 | frames_per_clip: 16 20 | num_classes: 400 21 | num_segments: 8 22 | num_views_per_segment: 3 23 | resolution: 256 24 | optimization: 25 | batch_size: 4 26 | multihead_kwargs: 27 | - final_lr: 0.0 28 | final_weight_decay: 0.01 29 | lr: 0.005 30 | start_lr: 0.005 31 | warmup: 0.0 32 | weight_decay: 0.01 33 | - final_lr: 0.0 34 | final_weight_decay: 0.01 35 | lr: 0.003 36 | start_lr: 0.003 37 | warmup: 0.0 38 | weight_decay: 0.01 39 | - final_lr: 0.0 40 | final_weight_decay: 0.01 41 | lr: 0.001 42 | start_lr: 0.001 43 | warmup: 0.0 44 | weight_decay: 0.01 45 | - final_lr: 0.0 46 | final_weight_decay: 0.01 47 | lr: 0.0003 48 | start_lr: 0.0003 49 | warmup: 0.0 50 | weight_decay: 0.01 51 | - final_lr: 0.0 52 | final_weight_decay: 0.01 53 | lr: 0.0001 54 | start_lr: 0.0001 55 | warmup: 0.0 56 | weight_decay: 0.01 57 | - final_lr: 0.0 58 | final_weight_decay: 0.1 59 | lr: 0.005 60 | start_lr: 0.005 61 | warmup: 0.0 62 | weight_decay: 0.1 63 | - final_lr: 0.0 64 | final_weight_decay: 0.1 65 | lr: 0.003 66 | start_lr: 0.003 67 | warmup: 0.0 68 | weight_decay: 0.1 69 | - final_lr: 0.0 70 | final_weight_decay: 0.1 71 | lr: 0.001 72 | start_lr: 0.001 73 | warmup: 0.0 74 | weight_decay: 0.1 75 | - final_lr: 0.0 76 | final_weight_decay: 0.1 77 | lr: 0.0003 78 | start_lr: 0.0003 79 | warmup: 0.0 80 | weight_decay: 0.1 81 | - final_lr: 0.0 82 | final_weight_decay: 0.1 83 | lr: 0.0001 84 | start_lr: 0.0001 85 | warmup: 0.0 86 | weight_decay: 0.1 87 | - final_lr: 0.0 88 | final_weight_decay: 0.4 89 | lr: 0.005 90 | start_lr: 0.005 91 | warmup: 0.0 92 | weight_decay: 0.4 93 | - final_lr: 0.0 94 | final_weight_decay: 0.4 95 | lr: 0.003 96 | start_lr: 0.003 97 | warmup: 0.0 98 | weight_decay: 0.4 99 | - final_lr: 0.0 100 | final_weight_decay: 0.4 101 | lr: 0.001 102 | start_lr: 0.001 103 | warmup: 0.0 104 | weight_decay: 0.4 105 | - final_lr: 0.0 106 | final_weight_decay: 0.4 107 | lr: 0.0003 108 | start_lr: 0.0003 109 | warmup: 0.0 110 | weight_decay: 0.4 111 | - final_lr: 0.0 112 | final_weight_decay: 0.4 113 | lr: 0.0001 114 | start_lr: 0.0001 115 | warmup: 0.0 116 | weight_decay: 0.4 117 | - final_lr: 0.0 118 | final_weight_decay: 0.8 119 | lr: 0.005 120 | start_lr: 0.005 121 | warmup: 0.0 122 | weight_decay: 0.8 123 | - final_lr: 0.0 124 | final_weight_decay: 0.8 125 | lr: 0.003 126 | start_lr: 0.003 127 | warmup: 0.0 128 | weight_decay: 0.8 129 | - final_lr: 0.0 130 | 
final_weight_decay: 0.8 131 | lr: 0.001 132 | start_lr: 0.001 133 | warmup: 0.0 134 | weight_decay: 0.8 135 | - final_lr: 0.0 136 | final_weight_decay: 0.8 137 | lr: 0.0003 138 | start_lr: 0.0003 139 | warmup: 0.0 140 | weight_decay: 0.8 141 | - final_lr: 0.0 142 | final_weight_decay: 0.8 143 | lr: 0.0001 144 | start_lr: 0.0001 145 | warmup: 0.0 146 | weight_decay: 0.8 147 | num_epochs: 20 148 | use_bfloat16: true 149 | use_pos_embed: false 150 | model_kwargs: 151 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 152 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 153 | pretrain_kwargs: 154 | encoder: 155 | checkpoint_key: target_encoder 156 | img_temporal_dim_size: null 157 | model_name: vit_large 158 | patch_size: 16 159 | tubelet_size: 2 160 | uniform_power: true 161 | use_rope: true 162 | wrapper_kwargs: 163 | max_frames: 128 164 | use_pos_embed: false 165 | -------------------------------------------------------------------------------- /configs/eval/vitl/ssv2.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitl/ssv2 4 | mem_per_gpu: 220G 5 | nodes: 8 6 | max_workers: 8 7 | resume_checkpoint: true 8 | tag: ssv2-vitl16-16x2x3-16f 9 | tasks_per_node: 8 10 | experiment: 11 | classifier: 12 | num_heads: 16 13 | num_probe_blocks: 4 14 | data: 15 | dataset_type: VideoDataset 16 | dataset_train: /your_data_path/ssv2_train_paths.csv 17 | dataset_val: /your_data_path/ssv2_val_paths.csv 18 | frame_step: 4 19 | frames_per_clip: 16 20 | num_classes: 174 21 | num_segments: 2 22 | num_views_per_segment: 3 23 | resolution: 256 24 | optimization: 25 | batch_size: 4 26 | multihead_kwargs: 27 | - final_lr: 0.0 28 | final_weight_decay: 0.01 29 | lr: 0.005 30 | start_lr: 0.005 31 | warmup: 0.0 32 | weight_decay: 0.01 33 | - final_lr: 0.0 34 | final_weight_decay: 0.01 35 | lr: 0.003 36 | start_lr: 0.003 37 | warmup: 0.0 38 | weight_decay: 0.01 39 | - final_lr: 0.0 40 | final_weight_decay: 0.01 41 | lr: 0.001 42 | start_lr: 0.001 43 | warmup: 0.0 44 | weight_decay: 0.01 45 | - final_lr: 0.0 46 | final_weight_decay: 0.01 47 | lr: 0.0003 48 | start_lr: 0.0003 49 | warmup: 0.0 50 | weight_decay: 0.01 51 | - final_lr: 0.0 52 | final_weight_decay: 0.01 53 | lr: 0.0001 54 | start_lr: 0.0001 55 | warmup: 0.0 56 | weight_decay: 0.01 57 | - final_lr: 0.0 58 | final_weight_decay: 0.1 59 | lr: 0.005 60 | start_lr: 0.005 61 | warmup: 0.0 62 | weight_decay: 0.1 63 | - final_lr: 0.0 64 | final_weight_decay: 0.1 65 | lr: 0.003 66 | start_lr: 0.003 67 | warmup: 0.0 68 | weight_decay: 0.1 69 | - final_lr: 0.0 70 | final_weight_decay: 0.1 71 | lr: 0.001 72 | start_lr: 0.001 73 | warmup: 0.0 74 | weight_decay: 0.1 75 | - final_lr: 0.0 76 | final_weight_decay: 0.1 77 | lr: 0.0003 78 | start_lr: 0.0003 79 | warmup: 0.0 80 | weight_decay: 0.1 81 | - final_lr: 0.0 82 | final_weight_decay: 0.1 83 | lr: 0.0001 84 | start_lr: 0.0001 85 | warmup: 0.0 86 | weight_decay: 0.1 87 | - final_lr: 0.0 88 | final_weight_decay: 0.4 89 | lr: 0.005 90 | start_lr: 0.005 91 | warmup: 0.0 92 | weight_decay: 0.4 93 | - final_lr: 0.0 94 | final_weight_decay: 0.4 95 | lr: 0.003 96 | start_lr: 0.003 97 | warmup: 0.0 98 | weight_decay: 0.4 99 | - final_lr: 0.0 100 | final_weight_decay: 0.4 101 | lr: 0.001 102 | start_lr: 0.001 103 | warmup: 0.0 104 | weight_decay: 0.4 105 | - final_lr: 0.0 106 | final_weight_decay: 0.4 107 | lr: 0.0003 108 | start_lr: 0.0003 109 | warmup: 0.0 110 | 
weight_decay: 0.4 111 | - final_lr: 0.0 112 | final_weight_decay: 0.4 113 | lr: 0.0001 114 | start_lr: 0.0001 115 | warmup: 0.0 116 | weight_decay: 0.4 117 | - final_lr: 0.0 118 | final_weight_decay: 0.8 119 | lr: 0.005 120 | start_lr: 0.005 121 | warmup: 0.0 122 | weight_decay: 0.8 123 | - final_lr: 0.0 124 | final_weight_decay: 0.8 125 | lr: 0.003 126 | start_lr: 0.003 127 | warmup: 0.0 128 | weight_decay: 0.8 129 | - final_lr: 0.0 130 | final_weight_decay: 0.8 131 | lr: 0.001 132 | start_lr: 0.001 133 | warmup: 0.0 134 | weight_decay: 0.8 135 | - final_lr: 0.0 136 | final_weight_decay: 0.8 137 | lr: 0.0003 138 | start_lr: 0.0003 139 | warmup: 0.0 140 | weight_decay: 0.8 141 | - final_lr: 0.0 142 | final_weight_decay: 0.8 143 | lr: 0.0001 144 | start_lr: 0.0001 145 | warmup: 0.0 146 | weight_decay: 0.8 147 | num_epochs: 20 148 | use_bfloat16: true 149 | use_pos_embed: false 150 | model_kwargs: 151 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 152 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 153 | pretrain_kwargs: 154 | encoder: 155 | checkpoint_key: target_encoder 156 | img_temporal_dim_size: null 157 | model_name: vit_large 158 | patch_size: 16 159 | tubelet_size: 2 160 | uniform_power: true 161 | use_rope: true 162 | wrapper_kwargs: 163 | max_frames: 128 164 | use_pos_embed: false 165 | -------------------------------------------------------------------------------- /configs/inference/vitg-384/diving48.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | mem_per_gpu: 200G 5 | tag: diving48-vitg16-384-32x4x3 6 | eval_name: video_classification_frozen 7 | folder: /your_folder/evals/vitg-384/diving48 8 | resume_checkpoint: true 9 | val_only: true 10 | experiment: 11 | classifier: 12 | num_probe_blocks: 4 13 | num_heads: 16 14 | data: 15 | dataset_type: VideoDataset 16 | dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv 17 | dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv 18 | num_classes: 48 19 | resolution: 384 20 | frames_per_clip: 32 21 | frame_step: 2 22 | num_segments: 4 23 | num_views_per_segment: 3 24 | optimization: 25 | use_pos_embed: false 26 | num_epochs: 100 27 | batch_size: 2 28 | use_bfloat16: true 29 | multihead_kwargs: 30 | - weight_decay: 0.0 31 | final_weight_decay: 0.0 32 | lr: 0.0 33 | start_lr: 0.0 34 | final_lr: 0.0 35 | warmup: 0.0 36 | model_kwargs: 37 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 38 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel 39 | wrapper_kwargs: 40 | max_frames: 128 41 | use_pos_embed: false 42 | out_layers: [24, 29, 34, 39] 43 | pretrain_kwargs: 44 | encoder: 45 | model_name: vit_giant_xformers 46 | checkpoint_key: target_encoder 47 | tubelet_size: 2 48 | patch_size: 16 49 | uniform_power: true 50 | use_rope: true 51 | -------------------------------------------------------------------------------- /configs/inference/vitg-384/ek100.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | tag: ek100-vitg16-384 5 | eval_name: action_anticipation_frozen 6 | folder: /your_folder/evals/vitg-384/ek100 7 | resume_checkpoint: true 8 | val_only: true 9 | experiment: 10 | classifier: 11 | num_probe_blocks: 4 12 | num_heads: 16 13 | data: 14 | anticipation_time_sec: 15 | - 1.0 16 | - 1.0 17 | auto_augment: true 18 | file_format: 0 
19 | dataset: EK100 20 | base_path: /your_ek100_root_dir/ 21 | dataset_train: /your_data/EPIC_100_train.csv 22 | dataset_val: /your_data/EPIC_100_validation.csv 23 | frames_per_clip: 32 24 | frames_per_second: 8 25 | motion_shift: false 26 | num_workers: 2 27 | pin_memory: true 28 | random_resize_scale: 29 | - 0.08 30 | - 1.0 31 | reprob: 0.25 32 | resolution: 384 33 | train_anticipation_point: 34 | - 0.0 35 | - 0.25 36 | train_anticipation_time_sec: 37 | - 0.25 38 | - 1.75 39 | optimization: 40 | num_epochs: 20 41 | batch_size: 2 42 | use_bfloat16: true 43 | use_focal_loss: true 44 | multihead_kwargs: 45 | - weight_decay: 0.0 46 | final_weight_decay: 0.0 47 | lr: 0.0 48 | start_lr: 0.0 49 | final_lr: 0.0 50 | warmup: 0.0 51 | model_kwargs: 52 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 53 | module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar 54 | wrapper_kwargs: 55 | no_predictor: false 56 | num_output_frames: 2 57 | num_steps: 1 58 | pretrain_kwargs: 59 | encoder: 60 | model_name: vit_giant_xformers 61 | checkpoint_key: target_encoder 62 | tubelet_size: 2 63 | patch_size: 16 64 | uniform_power: true 65 | use_rope: true 66 | predictor: 67 | model_name: vit_predictor 68 | checkpoint_key: predictor 69 | num_frames: 64 70 | depth: 12 71 | num_heads: 12 72 | predictor_embed_dim: 384 73 | num_mask_tokens: 10 74 | uniform_power: true 75 | use_mask_tokens: true 76 | use_sdpa: true 77 | use_silu: false 78 | wide_silu: false 79 | use_rope: true 80 | -------------------------------------------------------------------------------- /configs/inference/vitg-384/ssv2.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitg-384/ssv2 4 | mem_per_gpu: 220G 5 | nodes: 16 6 | num_workers: 8 7 | resume_checkpoint: true 8 | val_only: true 9 | tag: ssv2-vitg16-384-64x2x3 10 | tasks_per_node: 8 11 | experiment: 12 | classifier: 13 | num_heads: 16 14 | num_probe_blocks: 4 15 | data: 16 | dataset_type: VideoDataset 17 | dataset_train: /your_data_path/ssv2_train_paths.csv 18 | dataset_val: /your_data_path/ssv2_val_paths.csv 19 | frame_step: 2 20 | frames_per_clip: 64 21 | num_classes: 174 22 | num_segments: 2 23 | num_views_per_segment: 3 24 | resolution: 384 25 | optimization: 26 | batch_size: 2 27 | multihead_kwargs: 28 | - final_lr: 0.0 29 | final_weight_decay: 0.0 30 | lr: 0.0 31 | start_lr: 0.0 32 | warmup: 0.0 33 | weight_decay: 0.0 34 | num_epochs: 20 35 | use_bfloat16: true 36 | use_pos_embed: false 37 | model_kwargs: 38 | checkpoint: /your_vjepa2_checkpoints/vitg-384.pt 39 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 40 | pretrain_kwargs: 41 | encoder: 42 | checkpoint_key: target_encoder 43 | img_temporal_dim_size: null 44 | model_name: vit_giant_xformers 45 | patch_size: 16 46 | tubelet_size: 2 47 | uniform_power: true 48 | use_rope: true 49 | wrapper_kwargs: 50 | max_frames: 128 51 | use_pos_embed: false 52 | -------------------------------------------------------------------------------- /configs/inference/vitl/diving48.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | mem_per_gpu: 200G 5 | tag: diving48-vitl16-32x4x3 6 | eval_name: video_classification_frozen 7 | folder: /your_folder/evals/vitl/diving48 8 | resume_checkpoint: true 9 | val_only: true 10 | experiment: 11 | classifier: 12 | 
num_probe_blocks: 4 13 | num_heads: 16 14 | data: 15 | dataset_type: VideoDataset 16 | dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv 17 | dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv 18 | num_classes: 48 19 | resolution: 256 20 | frames_per_clip: 32 21 | frame_step: 2 22 | num_segments: 4 23 | num_views_per_segment: 3 24 | optimization: 25 | use_pos_embed: false 26 | num_epochs: 100 27 | batch_size: 2 28 | use_bfloat16: true 29 | multihead_kwargs: 30 | - weight_decay: 0.0 31 | final_weight_decay: 0.0 32 | lr: 0.0 33 | start_lr: 0.0 34 | final_lr: 0.0 35 | warmup: 0.0 36 | model_kwargs: 37 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 38 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel 39 | wrapper_kwargs: 40 | max_frames: 128 41 | use_pos_embed: false 42 | out_layers: [17, 19, 21, 23] 43 | pretrain_kwargs: 44 | encoder: 45 | model_name: vit_large 46 | checkpoint_key: target_encoder 47 | tubelet_size: 2 48 | patch_size: 16 49 | uniform_power: true 50 | use_rope: true 51 | -------------------------------------------------------------------------------- /configs/inference/vitl/ek100.yaml: -------------------------------------------------------------------------------- 1 | nodes: 8 2 | tasks_per_node: 8 3 | cpus_per_task: 12 4 | tag: ek100-vitl16 5 | eval_name: action_anticipation_frozen 6 | folder: /your_folder/evals/vitl/ek100 7 | resume_checkpoint: true 8 | val_only: true 9 | experiment: 10 | classifier: 11 | num_probe_blocks: 4 12 | num_heads: 16 13 | data: 14 | anticipation_time_sec: 15 | - 1.0 16 | - 1.0 17 | auto_augment: true 18 | file_format: 0 19 | dataset: EK100 20 | base_path: /your_ek100_root_dir/ 21 | dataset_train: /your_data/EPIC_100_train.csv 22 | dataset_val: /your_data/EPIC_100_validation.csv 23 | frames_per_clip: 32 24 | frames_per_second: 8 25 | motion_shift: false 26 | num_workers: 2 27 | pin_memory: true 28 | random_resize_scale: 29 | - 0.08 30 | - 1.0 31 | reprob: 0.25 32 | resolution: 256 33 | train_anticipation_point: 34 | - 0.0 35 | - 0.25 36 | train_anticipation_time_sec: 37 | - 0.25 38 | - 1.75 39 | optimization: 40 | num_epochs: 20 41 | batch_size: 2 42 | use_bfloat16: true 43 | use_focal_loss: true 44 | multihead_kwargs: 45 | - weight_decay: 0.0 46 | final_weight_decay: 0.0 47 | lr: 0.0 48 | start_lr: 0.0 49 | final_lr: 0.0 50 | warmup: 0.0 51 | model_kwargs: 52 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 53 | module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar 54 | wrapper_kwargs: 55 | no_predictor: false 56 | num_output_frames: 2 57 | num_steps: 1 58 | pretrain_kwargs: 59 | encoder: 60 | model_name: vit_large 61 | checkpoint_key: target_encoder 62 | tubelet_size: 2 63 | patch_size: 16 64 | uniform_power: true 65 | use_rope: true 66 | predictor: 67 | model_name: vit_predictor 68 | checkpoint_key: predictor 69 | num_frames: 64 70 | depth: 12 71 | num_heads: 12 72 | predictor_embed_dim: 384 73 | num_mask_tokens: 10 74 | uniform_power: true 75 | use_mask_tokens: true 76 | use_sdpa: true 77 | use_silu: false 78 | wide_silu: false 79 | use_rope: true 80 | -------------------------------------------------------------------------------- /configs/inference/vitl/ssv2.yaml: -------------------------------------------------------------------------------- 1 | cpus_per_task: 16 2 | eval_name: video_classification_frozen 3 | folder: /your_folder/evals/vitl/ssv2 4 | mem_per_gpu: 220G 5 | nodes: 8 6 | max_workers: 8 7 | 
resume_checkpoint: true 8 | val_only: true 9 | tag: ssv2-vitl16-16x2x3-16f 10 | tasks_per_node: 8 11 | experiment: 12 | classifier: 13 | num_heads: 16 14 | num_probe_blocks: 4 15 | data: 16 | dataset_type: VideoDataset 17 | dataset_train: /your_data_path/ssv2_train_paths.csv 18 | dataset_val: /your_data_path/ssv2_val_paths.csv 19 | frame_step: 4 20 | frames_per_clip: 16 21 | num_classes: 174 22 | num_segments: 2 23 | num_views_per_segment: 3 24 | resolution: 256 25 | optimization: 26 | batch_size: 4 27 | multihead_kwargs: 28 | - final_lr: 0.0 29 | final_weight_decay: 0.0 30 | lr: 0.0 31 | start_lr: 0.0 32 | warmup: 0.0 33 | weight_decay: 0.0 34 | num_epochs: 20 35 | use_bfloat16: true 36 | use_pos_embed: false 37 | model_kwargs: 38 | checkpoint: /your_vjepa2_checkpoints/vitl.pt 39 | module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip 40 | pretrain_kwargs: 41 | encoder: 42 | checkpoint_key: target_encoder 43 | img_temporal_dim_size: null 44 | model_name: vit_large 45 | patch_size: 16 46 | tubelet_size: 2 47 | uniform_power: true 48 | use_rope: true 49 | wrapper_kwargs: 50 | max_frames: 128 51 | use_pos_embed: false 52 | -------------------------------------------------------------------------------- /configs/train/vitg16/cooldown-256px-64f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | cpus_per_task: 32 3 | folder: /your_folder/anneal/64.8.vitg16-256px-64f 4 | mem_per_gpu: 220G 5 | nodes: 64 6 | tasks_per_node: 8 7 | data: 8 | dataset_type: VideoDataset 9 | datasets: 10 | - /your_k710_root_dir/k710_train_paths.csv 11 | - /your_data_path/ssv2_train_paths.csv 12 | - /your_data/howto_320p.csv 13 | datasets_weights: 14 | - 0.335 15 | - 0.100 16 | - 0.565 17 | batch_size: 6 18 | crop_size: 256 19 | dataset_fpcs: 20 | - 64 21 | - 64 22 | - 64 23 | fps: 4 24 | num_workers: 12 25 | patch_size: 16 26 | persistent_workers: true 27 | pin_mem: false 28 | tubelet_size: 2 29 | data_aug: 30 | auto_augment: false 31 | motion_shift: false 32 | random_resize_aspect_ratio: 33 | - 0.75 34 | - 1.35 35 | random_resize_scale: 36 | - 0.3 37 | - 1.0 38 | reprob: 0.0 39 | loss: 40 | loss_exp: 1.0 41 | mask: 42 | - aspect_ratio: 43 | - 0.75 44 | - 1.5 45 | full_complement: false 46 | max_keep: null 47 | max_temporal_keep: 1.0 48 | num_blocks: 8 49 | spatial_scale: 50 | - 0.15 51 | - 0.15 52 | temporal_scale: 53 | - 1.0 54 | - 1.0 55 | - aspect_ratio: 56 | - 0.75 57 | - 1.5 58 | full_complement: false 59 | max_keep: null 60 | max_temporal_keep: 1.0 61 | num_blocks: 2 62 | spatial_scale: 63 | - 0.7 64 | - 0.7 65 | temporal_scale: 66 | - 1.0 67 | - 1.0 68 | meta: 69 | dtype: bfloat16 70 | eval_freq: 100 71 | load_checkpoint: true 72 | read_checkpoint: null 73 | save_every_freq: 50 74 | seed: 239 75 | use_sdpa: true 76 | model: 77 | model_name: vit_giant_xformers 78 | pred_depth: 12 79 | pred_embed_dim: 384 80 | pred_num_heads: 12 81 | uniform_power: true 82 | use_activation_checkpointing: true 83 | use_mask_tokens: true 84 | use_rope: true 85 | zero_init_mask_tokens: true 86 | optimization: 87 | anneal_ckpt: /your_folder/pretrain/16.8.vitg.256px.16f/e0.pt 88 | ema: 89 | - 0.99925 90 | - 0.99925 91 | epochs: 4 92 | final_lr: 1.0e-06 93 | final_weight_decay: 0.04 94 | ipe: 30 95 | ipe_scale: 1.25 96 | is_anneal: true 97 | lr: 0.000525 98 | resume_anneal: true 99 | start_lr: 0.0001 100 | warmup: 0 101 | weight_decay: 0.04 102 | -------------------------------------------------------------------------------- 
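The optimization block in the cooldown config above resumes from anneal_ckpt and runs a short anneal: with warmup set to 0 the learning rate starts directly at lr (0.000525) and decays toward final_lr (1e-6) over roughly epochs * ipe * ipe_scale steps. The following is a minimal sketch of how those fields could combine, assuming a linear-warmup plus cosine-decay schedule like the WarmupCosineLRSchedule defined in evals/action_anticipation_frozen/utils.py later in this dump; the trainer's actual scheduler (src/utils/schedulers.py) is not shown here, and the mapping of ipe/ipe_scale onto step counts is an assumption.

import math

def lr_at_step(step, *, start_lr=1.0e-4, ref_lr=5.25e-4, final_lr=1.0e-6,
               warmup_epochs=0, epochs=4, ipe=30, ipe_scale=1.25):
    # Assumed field mapping: warmup/epochs are in epochs, ipe is iterations per
    # epoch, and ipe_scale stretches the cosine horizon past the training run.
    warmup_steps = int(warmup_epochs * ipe)
    total_steps = int(epochs * ipe * ipe_scale)
    if step < warmup_steps:
        # linear warmup from start_lr up to the reference lr
        return start_lr + (ref_lr - start_lr) * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    # cosine decay from ref_lr down to final_lr
    return max(final_lr, final_lr + (ref_lr - final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)))

With warmup at 0, lr_at_step(0) returns 0.000525 and the value decays smoothly toward 1e-6 by the end of the cooldown; the default arguments above are copied from cooldown-256px-64f.yaml.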
/configs/train/vitg16/cooldown-384px-64f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | cpus_per_task: 32 3 | folder: /your_folder/anneal/64.8.vitg16-384px-64f 4 | mem_per_gpu: 220G 5 | nodes: 64 6 | tasks_per_node: 8 7 | data: 8 | dataset_type: VideoDataset 9 | datasets: 10 | - /your_k710_root_dir/k710_train_paths.csv 11 | - /your_data_path/ssv2_train_paths.csv 12 | - /your_data/howto_320p.csv 13 | datasets_weights: 14 | - 0.335 15 | - 0.100 16 | - 0.565 17 | batch_size: 6 18 | crop_size: 384 19 | dataset_fpcs: 20 | - 64 21 | - 64 22 | - 64 23 | fps: 4 24 | num_workers: 12 25 | patch_size: 16 26 | persistent_workers: true 27 | pin_mem: false 28 | tubelet_size: 2 29 | data_aug: 30 | auto_augment: false 31 | motion_shift: false 32 | random_resize_aspect_ratio: 33 | - 0.75 34 | - 1.35 35 | random_resize_scale: 36 | - 0.3 37 | - 1.0 38 | reprob: 0.0 39 | loss: 40 | loss_exp: 1.0 41 | mask: 42 | - aspect_ratio: 43 | - 0.75 44 | - 1.5 45 | full_complement: false 46 | max_keep: null 47 | max_temporal_keep: 1.0 48 | num_blocks: 8 49 | spatial_scale: 50 | - 0.15 51 | - 0.15 52 | temporal_scale: 53 | - 1.0 54 | - 1.0 55 | - aspect_ratio: 56 | - 0.75 57 | - 1.5 58 | full_complement: false 59 | max_keep: null 60 | max_temporal_keep: 1.0 61 | num_blocks: 2 62 | spatial_scale: 63 | - 0.7 64 | - 0.7 65 | temporal_scale: 66 | - 1.0 67 | - 1.0 68 | meta: 69 | dtype: bfloat16 70 | eval_freq: 100 71 | load_checkpoint: true 72 | read_checkpoint: null 73 | save_every_freq: 50 74 | seed: 239 75 | use_sdpa: true 76 | model: 77 | model_name: vit_giant_xformers 78 | pred_depth: 12 79 | pred_embed_dim: 384 80 | pred_num_heads: 12 81 | uniform_power: true 82 | use_activation_checkpointing: true 83 | use_mask_tokens: true 84 | use_rope: true 85 | zero_init_mask_tokens: true 86 | optimization: 87 | anneal_ckpt: /your_folder/pretrain/16.8.vitg.256px.16f/e0.pt 88 | ema: 89 | - 0.99925 90 | - 0.99925 91 | epochs: 40 92 | final_lr: 1.0e-06 93 | final_weight_decay: 0.04 94 | ipe: 300 95 | ipe_scale: 1.25 96 | is_anneal: true 97 | lr: 0.000525 98 | resume_anneal: true 99 | start_lr: 0.0001 100 | warmup: 0 101 | weight_decay: 0.04 102 | -------------------------------------------------------------------------------- /configs/train/vitg16/droid-256px-8f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa_droid 2 | cpus_per_task: 12 3 | folder: /your_folder/droid/4.8.vitg16-256px-8f 4 | mem_per_gpu: 220G 5 | nodes: 4 6 | tasks_per_node: 8 7 | data: 8 | batch_size: 8 9 | camera_views: 10 | - left_mp4_path 11 | crop_size: 256 12 | datasets: 13 | - /your_file_path/droid_train_paths_cw.csv 14 | dataset_fpcs: 15 | - 8 16 | fps: 4 17 | num_workers: 12 18 | patch_size: 16 19 | pin_mem: true 20 | stereo_view: false 21 | tubelet_size: 2 22 | data_aug: 23 | auto_augment: false 24 | horizontal_flip: false 25 | motion_shift: false 26 | random_resize_aspect_ratio: 27 | - 0.75 28 | - 1.35 29 | random_resize_scale: 30 | - 1.777 31 | - 1.777 32 | reprob: 0.0 33 | loss: 34 | auto_steps: 2 35 | loss_exp: 1.0 36 | normalize_reps: true 37 | reg_coeff: 0.0 38 | meta: 39 | dtype: bfloat16 40 | eval_freq: 100 41 | resume_checkpoint: null 42 | load_predictor: false 43 | pretrain_checkpoint: /your_vjepa2_checkpoints/vitg.pt 44 | context_encoder_key: target_encoder 45 | target_encoder_key: target_encoder 46 | save_every_freq: 25 47 | seed: 239 48 | use_sdpa: true 49 | model: 50 | model_name: vit_giant_xformers 51 | pred_depth: 24 52 | 
pred_embed_dim: 1024 53 | pred_is_frame_causal: true 54 | pred_num_heads: 16 55 | uniform_power: true 56 | use_activation_checkpointing: true 57 | use_extrinsics: false 58 | use_rope: true 59 | optimization: 60 | anneal: 15 61 | epochs: 315 62 | final_lr: 0.0 63 | final_weight_decay: 0.04 64 | ipe: 300 65 | lr: 0.000425 66 | start_lr: 0.000075 67 | warmup: 15 68 | weight_decay: 0.04 69 | -------------------------------------------------------------------------------- /configs/train/vitg16/pretrain-256px-16f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | nodes: 16 3 | tasks_per_node: 8 4 | cpus_per_task: 16 5 | mem_per_gpu: 220G 6 | folder: /your_folder/pretrain/16.8.vitg.256px.16f 7 | data: 8 | dataset_type: VideoDataset 9 | datasets: 10 | - /your_k710_root_dir/k710_train_paths.csv 11 | - /your_data_path/ssv2_train_paths.csv 12 | - /your_data/howto_320p.csv 13 | datasets_weights: 14 | - 0.335 15 | - 0.100 16 | - 0.565 17 | batch_size: 24 18 | crop_size: 256 19 | patch_size: 16 20 | dataset_fpcs: 21 | - 16 22 | - 16 23 | - 16 24 | tubelet_size: 2 25 | fps: 4 26 | num_workers: 8 27 | persistent_workers: true 28 | pin_mem: true 29 | data_aug: 30 | auto_augment: false 31 | motion_shift: false 32 | random_resize_aspect_ratio: 33 | - 0.75 34 | - 1.35 35 | random_resize_scale: 36 | - 0.3 37 | - 1.0 38 | reprob: 0.0 39 | loss: 40 | loss_exp: 1.0 41 | mask: 42 | - aspect_ratio: 43 | - 0.75 44 | - 1.5 45 | full_complement: false 46 | max_keep: null 47 | max_temporal_keep: 1.0 48 | num_blocks: 8 49 | spatial_scale: 50 | - 0.15 51 | - 0.15 52 | temporal_scale: 53 | - 1.0 54 | - 1.0 55 | - aspect_ratio: 56 | - 0.75 57 | - 1.5 58 | full_complement: false 59 | max_keep: null 60 | max_temporal_keep: 1.0 61 | num_blocks: 2 62 | spatial_scale: 63 | - 0.7 64 | - 0.7 65 | temporal_scale: 66 | - 1.0 67 | - 1.0 68 | meta: 69 | dtype: bfloat16 70 | eval_freq: 100 71 | load_checkpoint: true 72 | read_checkpoint: null 73 | save_every_freq: 50 74 | seed: 239 75 | use_sdpa: true 76 | model: 77 | model_name: vit_giant_xformers 78 | pred_depth: 12 79 | pred_embed_dim: 384 80 | pred_num_heads: 12 81 | uniform_power: true 82 | use_activation_checkpointing: true 83 | use_mask_tokens: true 84 | use_rope: true 85 | zero_init_mask_tokens: true 86 | optimization: 87 | ema: 88 | - 0.99925 89 | - 0.99925 90 | epochs: 800 91 | final_lr: 0.000525 92 | final_weight_decay: 0.04 93 | ipe: 300 94 | ipe_scale: 1.25 95 | lr: 0.000525 96 | start_lr: 0.0001 97 | warmup: 40 98 | weight_decay: 0.04 99 | -------------------------------------------------------------------------------- /configs/train/vith16/cooldown-256px-64f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | cpus_per_task: 32 3 | folder: /your_folder/anneal/32.8.vith16-256px-64f 4 | mem_per_gpu: 220G 5 | nodes: 32 6 | tasks_per_node: 8 7 | data: 8 | dataset_type: VideoDataset 9 | datasets: 10 | - /your_k710_root_dir/k710_train_paths.csv 11 | - /your_data_path/ssv2_train_paths.csv 12 | - /your_data/howto_320p.csv 13 | datasets_weights: 14 | - 0.335 15 | - 0.100 16 | - 0.565 17 | batch_size: 12 18 | crop_size: 256 19 | dataset_fpcs: 20 | - 64 21 | - 64 22 | - 64 23 | fps: 4 24 | num_workers: 12 25 | patch_size: 16 26 | persistent_workers: true 27 | pin_mem: false 28 | tubelet_size: 2 29 | data_aug: 30 | auto_augment: false 31 | motion_shift: false 32 | random_resize_aspect_ratio: 33 | - 0.75 34 | - 1.35 35 | random_resize_scale: 36 | - 0.3 37 | - 1.0 38 | reprob: 
0.0 39 | loss: 40 | loss_exp: 1.0 41 | mask: 42 | - aspect_ratio: 43 | - 0.75 44 | - 1.5 45 | full_complement: false 46 | max_keep: null 47 | max_temporal_keep: 1.0 48 | num_blocks: 8 49 | spatial_scale: 50 | - 0.15 51 | - 0.15 52 | temporal_scale: 53 | - 1.0 54 | - 1.0 55 | - aspect_ratio: 56 | - 0.75 57 | - 1.5 58 | full_complement: false 59 | max_keep: null 60 | max_temporal_keep: 1.0 61 | num_blocks: 2 62 | spatial_scale: 63 | - 0.7 64 | - 0.7 65 | temporal_scale: 66 | - 1.0 67 | - 1.0 68 | meta: 69 | dtype: bfloat16 70 | eval_freq: 100 71 | load_checkpoint: true 72 | read_checkpoint: null 73 | save_every_freq: 50 74 | seed: 239 75 | use_sdpa: true 76 | model: 77 | model_name: vit_huge 78 | pred_depth: 12 79 | pred_embed_dim: 384 80 | pred_num_heads: 12 81 | uniform_power: true 82 | use_activation_checkpointing: true 83 | use_mask_tokens: true 84 | use_rope: true 85 | zero_init_mask_tokens: true 86 | optimization: 87 | anneal_ckpt: /your_folder/pretrain/16.8.vith.256px.16f/e0.pt 88 | ema: 89 | - 0.99925 90 | - 0.99925 91 | epochs: 40 92 | final_lr: 1.0e-06 93 | final_weight_decay: 0.04 94 | ipe: 300 95 | ipe_scale: 1.25 96 | is_anneal: true 97 | lr: 0.000525 98 | resume_anneal: true 99 | start_lr: 0.0001 100 | warmup: 0 101 | weight_decay: 0.04 102 | -------------------------------------------------------------------------------- /configs/train/vith16/pretrain-256px-16f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | nodes: 16 3 | tasks_per_node: 8 4 | cpus_per_task: 16 5 | mem_per_gpu: 220G 6 | folder: /your_folder/pretrain/16.8.vith.256px.16f 7 | data: 8 | dataset_type: VideoDataset 9 | datasets: 10 | - /your_k710_root_dir/k710_train_paths.csv 11 | - /your_data_path/ssv2_train_paths.csv 12 | - /your_data/howto_320p.csv 13 | datasets_weights: 14 | - 0.335 15 | - 0.100 16 | - 0.565 17 | batch_size: 24 18 | crop_size: 256 19 | patch_size: 16 20 | dataset_fpcs: 21 | - 16 22 | - 16 23 | - 16 24 | tubelet_size: 2 25 | fps: 4 26 | num_workers: 8 27 | persistent_workers: true 28 | pin_mem: true 29 | data_aug: 30 | auto_augment: false 31 | motion_shift: false 32 | random_resize_aspect_ratio: 33 | - 0.75 34 | - 1.35 35 | random_resize_scale: 36 | - 0.3 37 | - 1.0 38 | reprob: 0.0 39 | loss: 40 | loss_exp: 1.0 41 | mask: 42 | - aspect_ratio: 43 | - 0.75 44 | - 1.5 45 | full_complement: false 46 | max_keep: null 47 | max_temporal_keep: 1.0 48 | num_blocks: 8 49 | spatial_scale: 50 | - 0.15 51 | - 0.15 52 | temporal_scale: 53 | - 1.0 54 | - 1.0 55 | - aspect_ratio: 56 | - 0.75 57 | - 1.5 58 | full_complement: false 59 | max_keep: null 60 | max_temporal_keep: 1.0 61 | num_blocks: 2 62 | spatial_scale: 63 | - 0.7 64 | - 0.7 65 | temporal_scale: 66 | - 1.0 67 | - 1.0 68 | meta: 69 | dtype: bfloat16 70 | eval_freq: 100 71 | load_checkpoint: true 72 | read_checkpoint: null 73 | save_every_freq: 50 74 | seed: 239 75 | use_sdpa: true 76 | model: 77 | model_name: vit_huge 78 | pred_depth: 12 79 | pred_embed_dim: 384 80 | pred_num_heads: 12 81 | uniform_power: true 82 | use_activation_checkpointing: true 83 | use_mask_tokens: true 84 | use_rope: true 85 | zero_init_mask_tokens: true 86 | optimization: 87 | ema: 88 | - 0.99925 89 | - 0.99925 90 | epochs: 10 91 | final_lr: 0.000525 92 | final_weight_decay: 0.04 93 | ipe: 300 94 | ipe_scale: 1.25 95 | lr: 0.000525 96 | start_lr: 0.0001 97 | warmup: 40 98 | weight_decay: 0.04 99 | -------------------------------------------------------------------------------- 
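The two entries under mask in the pretraining config above define the multi-block 3D masking: eight short-range blocks that each cover about 15% of the spatial grid, plus two long-range blocks covering about 70%, all spanning the full temporal extent (temporal_scale 1.0). With crop_size 256, patch_size 16, 16 frames per clip and tubelet_size 2, the token grid is 8 x 16 x 16, i.e. 2048 tokens per clip. Below is a rough sketch of how a single block could be sized from these fields; it is an illustration only, assuming behaviour similar to src/masks/multiseq_multiblock3d.py, whose code is not shown in this section.

import math
import random

def sample_block(spatial_scale=0.15, aspect_ratio=(0.75, 1.5), temporal_scale=1.0,
                 grid_t=8, grid_h=16, grid_w=16):
    # target number of spatial patches covered by one block (~38 at scale 0.15)
    n_spatial = spatial_scale * grid_h * grid_w
    ar = random.uniform(*aspect_ratio)               # sampled block aspect ratio
    h = min(grid_h, max(1, round(math.sqrt(n_spatial * ar))))
    w = min(grid_w, max(1, round(math.sqrt(n_spatial / ar))))
    t = max(1, int(temporal_scale * grid_t))         # 1.0 -> spans all 8 time steps
    return t, h, w

The first mask entry samples num_blocks: 8 such regions at spatial_scale 0.15 and the second samples num_blocks: 2 larger regions at spatial_scale 0.7; the masked union is hidden from the context encoder and predicted in representation space by the predictor.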
/configs/train/vitl16/cooldown-256px-64f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | cpus_per_task: 32 3 | folder: /your_folder/anneal/32.8.vitl16-256px-64f 4 | mem_per_gpu: 220G 5 | nodes: 32 6 | tasks_per_node: 8 7 | data: 8 | dataset_type: VideoDataset 9 | datasets: 10 | - /your_k710_root_dir/k710_train_paths.csv 11 | - /your_data_path/ssv2_train_paths.csv 12 | - /your_data/howto_320p.csv 13 | datasets_weights: 14 | - 0.335 15 | - 0.100 16 | - 0.565 17 | batch_size: 12 18 | crop_size: 256 19 | dataset_fpcs: 20 | - 64 21 | - 64 22 | - 64 23 | fps: 4 24 | num_workers: 12 25 | patch_size: 16 26 | persistent_workers: true 27 | pin_mem: false 28 | tubelet_size: 2 29 | data_aug: 30 | auto_augment: false 31 | motion_shift: false 32 | random_resize_aspect_ratio: 33 | - 0.75 34 | - 1.35 35 | random_resize_scale: 36 | - 0.3 37 | - 1.0 38 | reprob: 0.0 39 | loss: 40 | loss_exp: 1.0 41 | mask: 42 | - aspect_ratio: 43 | - 0.75 44 | - 1.5 45 | full_complement: false 46 | max_keep: null 47 | max_temporal_keep: 1.0 48 | num_blocks: 8 49 | spatial_scale: 50 | - 0.15 51 | - 0.15 52 | temporal_scale: 53 | - 1.0 54 | - 1.0 55 | - aspect_ratio: 56 | - 0.75 57 | - 1.5 58 | full_complement: false 59 | max_keep: null 60 | max_temporal_keep: 1.0 61 | num_blocks: 2 62 | spatial_scale: 63 | - 0.7 64 | - 0.7 65 | temporal_scale: 66 | - 1.0 67 | - 1.0 68 | meta: 69 | dtype: bfloat16 70 | eval_freq: 100 71 | load_checkpoint: true 72 | read_checkpoint: null 73 | save_every_freq: 50 74 | seed: 239 75 | use_sdpa: true 76 | model: 77 | model_name: vit_large 78 | pred_depth: 12 79 | pred_embed_dim: 384 80 | pred_num_heads: 12 81 | uniform_power: true 82 | use_activation_checkpointing: true 83 | use_mask_tokens: true 84 | use_rope: true 85 | zero_init_mask_tokens: true 86 | optimization: 87 | anneal_ckpt: /your_folder/pretrain/16.8.vitl.256px.16f/e0.pt 88 | ema: 89 | - 0.99925 90 | - 0.99925 91 | epochs: 40 92 | final_lr: 1.0e-06 93 | final_weight_decay: 0.04 94 | ipe: 300 95 | ipe_scale: 1.25 96 | is_anneal: true 97 | lr: 0.000525 98 | resume_anneal: true 99 | start_lr: 0.0001 100 | warmup: 0 101 | weight_decay: 0.04 102 | -------------------------------------------------------------------------------- /configs/train/vitl16/pretrain-256px-16f.yaml: -------------------------------------------------------------------------------- 1 | app: vjepa 2 | nodes: 16 3 | tasks_per_node: 8 4 | cpus_per_task: 16 5 | mem_per_gpu: 220G 6 | folder: /your_folder/pretrain/16.8.vitl.256px.16f 7 | data: 8 | dataset_type: VideoDataset 9 | datasets: 10 | - /your_k710_root_dir/k710_train_paths.csv 11 | - /your_data_path/ssv2_train_paths.csv 12 | - /your_data/howto_320p.csv 13 | datasets_weights: 14 | - 0.335 15 | - 0.100 16 | - 0.565 17 | batch_size: 24 18 | crop_size: 256 19 | patch_size: 16 20 | dataset_fpcs: 21 | - 16 22 | - 16 23 | - 16 24 | tubelet_size: 2 25 | fps: 4 26 | num_workers: 8 27 | persistent_workers: true 28 | pin_mem: true 29 | data_aug: 30 | auto_augment: false 31 | motion_shift: false 32 | random_resize_aspect_ratio: 33 | - 0.75 34 | - 1.35 35 | random_resize_scale: 36 | - 0.3 37 | - 1.0 38 | reprob: 0.0 39 | loss: 40 | loss_exp: 1.0 41 | mask: 42 | - aspect_ratio: 43 | - 0.75 44 | - 1.5 45 | full_complement: false 46 | max_keep: null 47 | max_temporal_keep: 1.0 48 | num_blocks: 8 49 | spatial_scale: 50 | - 0.15 51 | - 0.15 52 | temporal_scale: 53 | - 1.0 54 | - 1.0 55 | - aspect_ratio: 56 | - 0.75 57 | - 1.5 58 | full_complement: false 59 | max_keep: null 60 | 
max_temporal_keep: 1.0 61 | num_blocks: 2 62 | spatial_scale: 63 | - 0.7 64 | - 0.7 65 | temporal_scale: 66 | - 1.0 67 | - 1.0 68 | meta: 69 | dtype: bfloat16 70 | eval_freq: 100 71 | load_checkpoint: true 72 | read_checkpoint: null 73 | save_every_freq: 50 74 | seed: 239 75 | use_sdpa: true 76 | model: 77 | model_name: vit_large 78 | pred_depth: 12 79 | pred_embed_dim: 384 80 | pred_num_heads: 12 81 | uniform_power: true 82 | use_activation_checkpointing: true 83 | use_mask_tokens: true 84 | use_rope: true 85 | zero_init_mask_tokens: true 86 | optimization: 87 | ema: 88 | - 0.99925 89 | - 0.99925 90 | epochs: 10 91 | final_lr: 0.000525 92 | final_weight_decay: 0.04 93 | ipe: 300 94 | ipe_scale: 1.25 95 | lr: 0.000525 96 | start_lr: 0.0001 97 | warmup: 40 98 | weight_decay: 0.04 99 | -------------------------------------------------------------------------------- /evals/action_anticipation_frozen/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | def sigmoid_focal_loss( 10 | inputs, 11 | targets, 12 | alpha=0.25, 13 | gamma=2.0, 14 | reduction="sum", 15 | detach=False, 16 | ): 17 | """ 18 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 19 | 20 | :param Tensor inputs: Prediction logits for each sample [B x K] 21 | :param Tensor targets: Class label for each sample [B] (long tensor) 22 | :param float alpha: Weight in range (0,1) to balance pos vs neg samples. 23 | :param float gamma: Exponent of modulating factor (1-p_t) to balance easy vs hard samples. 24 | :param str reduction: 'mean' | 'sum' 25 | """ 26 | B, K = inputs.size() # [batch_size, class logits] 27 | 28 | # convert to one-hot targets 29 | targets = F.one_hot(targets, K).float() # [B, K] 30 | 31 | p = F.sigmoid(inputs) 32 | 33 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 34 | p_t = p * targets + (1 - p) * (1 - targets) 35 | loss = ce_loss * ((1 - p_t) ** gamma) 36 | 37 | if alpha >= 0: 38 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 39 | loss = alpha_t * loss 40 | 41 | if reduction == "mean": 42 | loss = loss.mean() 43 | elif reduction == "sum": 44 | loss = loss.sum() 45 | 46 | if detach: 47 | loss = loss.detach() 48 | 49 | return loss 50 | -------------------------------------------------------------------------------- /evals/action_anticipation_frozen/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn.functional as F 9 | 10 | 11 | class ClassMeanRecall: 12 | 13 | def __init__(self, num_classes: int, device: torch.device, k=5): 14 | self.num_classes = num_classes 15 | self.TP = torch.zeros(num_classes).to(device) 16 | self.FN = torch.zeros(num_classes).to(device) 17 | self.k = k 18 | 19 | def __call__(self, logits, labels, valid_classes=None, eps=1e-8): 20 | """ 21 | :param logits: Tensors of shape [B, num_classes] 22 | :param labels: Tensors of shape [B] 23 | :param valid_classes: set 24 | """ 25 | k, tp_tensor, fn_tensor = self.k, self.TP, self.FN 26 | logits = F.sigmoid(logits) 27 | 28 | if valid_classes is not None: 29 | _logits = torch.zeros(logits.shape).to(logits.device) 30 | for c in valid_classes: 31 | _logits[:, c] = logits[:, c] 32 | logits = _logits 33 | 34 | preds = logits.topk(k, dim=1).indices 35 | 36 | # Loop over batch and check whether all targets are within top-k logit 37 | # predictions for their respective class, if so TP else FN 38 | for p, gt in zip(preds, labels): 39 | if gt in p: 40 | tp_tensor[gt] += 1 41 | else: 42 | fn_tensor[gt] += 1 43 | 44 | # Aggregate TP/FN across all workers, but need to detach so that we 45 | # don't accidentally update tp_tensor and fn_tensor, which 46 | # only track local quantities. 47 | TP, FN = tp_tensor.clone(), fn_tensor.clone() 48 | dist.all_reduce(TP) 49 | dist.all_reduce(FN) 50 | 51 | nch = torch.sum((TP + FN) > 0) # num classes hit; may not have TP/FP data for all classes yet 52 | recall = 100.0 * torch.sum(TP / (TP + FN + eps)) / nch # mean class recall 53 | topk = 100.0 * sum(TP) / int(sum(TP + FN)) # accuracy 54 | 55 | return dict( 56 | recall=recall, 57 | accuracy=topk, 58 | ) 59 | -------------------------------------------------------------------------------- /evals/action_anticipation_frozen/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import importlib 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from src.models.attentive_pooler import AttentivePooler 13 | 14 | logging.basicConfig() 15 | logger = logging.getLogger() 16 | logger.setLevel(logging.INFO) 17 | 18 | 19 | class AttentiveClassifier(nn.Module): 20 | 21 | def __init__( 22 | self, 23 | verb_classes: dict, 24 | noun_classes: dict, 25 | action_classes: dict, 26 | embed_dim: int, 27 | num_heads: int, 28 | depth: int, 29 | use_activation_checkpointing: bool, 30 | ): 31 | super().__init__() 32 | self.num_verb_classes = len(verb_classes) 33 | num_noun_classes = len(noun_classes) 34 | num_action_classes = len(action_classes) 35 | self.action_only = self.num_verb_classes == 0 36 | 37 | self.pooler = AttentivePooler( 38 | num_queries=1 if self.action_only else 3, 39 | embed_dim=embed_dim, 40 | num_heads=num_heads, 41 | depth=depth, 42 | use_activation_checkpointing=use_activation_checkpointing, 43 | ) 44 | if not self.action_only: 45 | self.verb_classifier = nn.Linear(embed_dim, self.num_verb_classes, bias=True) 46 | self.noun_classifier = nn.Linear(embed_dim, num_noun_classes, bias=True) 47 | self.action_classifier = nn.Linear(embed_dim, num_action_classes, bias=True) 48 | 49 | def forward(self, x): 50 | if torch.isnan(x).any(): 51 | print("Nan detected at output of encoder") 52 | exit(1) 53 | 54 | x = self.pooler(x) # [B, 2, D] 55 | if not self.action_only: 56 | x_verb, x_noun, x_action = x[:, 0, :], x[:, 1, :], x[:, 2, :] 57 | x_verb = self.verb_classifier(x_verb) 58 | x_noun = self.noun_classifier(x_noun) 59 | x_action = self.action_classifier(x_action) 60 | return dict( 61 | verb=x_verb, 62 | noun=x_noun, 63 | action=x_action, 64 | ) 65 | else: 66 | x_action = x[:, 0, :] 67 | x_action = self.action_classifier(x_action) 68 | return dict(action=x_action) 69 | 70 | 71 | def init_module( 72 | module_name, 73 | device, 74 | frames_per_clip, 75 | frames_per_second, 76 | resolution, 77 | checkpoint, 78 | model_kwargs, 79 | wrapper_kwargs, 80 | ): 81 | """ 82 | Build (frozen) model and initialize from pretrained checkpoint 83 | 84 | API requirements for "model" module: 85 | 1) Needs to be a pytorch module with 'forward()' function protocol: 86 | :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width]) 87 | :param anticipation_time: (Tensor) Seconds into the future to predict for each sample in batch 88 | (shape=[batch_size]) 89 | :returns: (Tensor) Representations of future frames (shape=[batch_size x num_output_tokens x feature_dim]) 90 | 91 | 2) Needs to have a public attribute called 'embed_dim' (int) describing its 92 | output feature dimension. 
93 | """ 94 | model = ( 95 | importlib.import_module(f"{module_name}") 96 | .init_module( 97 | frames_per_clip=frames_per_clip, 98 | frames_per_second=frames_per_second, 99 | resolution=resolution, 100 | checkpoint=checkpoint, 101 | model_kwargs=model_kwargs, 102 | wrapper_kwargs=wrapper_kwargs, 103 | ) 104 | .to(device) 105 | ) 106 | model.eval() 107 | for p in model.parameters(): 108 | p.requires_grad = False 109 | print(model) 110 | return model 111 | 112 | 113 | def init_classifier( 114 | embed_dim: int, 115 | num_heads: int, 116 | num_blocks: int, 117 | device: torch.device, 118 | num_classifiers: int, 119 | action_classes: dict, 120 | verb_classes: dict, 121 | noun_classes: dict, 122 | ): 123 | classifiers = [ 124 | AttentiveClassifier( 125 | verb_classes=verb_classes, 126 | noun_classes=noun_classes, 127 | action_classes=action_classes, 128 | embed_dim=embed_dim, 129 | num_heads=num_heads, 130 | depth=num_blocks, 131 | use_activation_checkpointing=True, 132 | ).to(device) 133 | for _ in range(num_classifiers) 134 | ] 135 | print(classifiers[0]) 136 | return classifiers 137 | -------------------------------------------------------------------------------- /evals/action_anticipation_frozen/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | import math 8 | 9 | import torch 10 | 11 | logging.basicConfig() 12 | logger = logging.getLogger() 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | def init_opt(classifiers, iterations_per_epoch, opt_kwargs, num_epochs, use_bfloat16=False): 17 | optimizers, schedulers, wd_schedulers, scalers = [], [], [], [] 18 | for c, kwargs in zip(classifiers, opt_kwargs): 19 | param_groups = [ 20 | { 21 | "params": (p for n, p in c.named_parameters()), 22 | "mc_warmup_steps": int(kwargs.get("warmup") * iterations_per_epoch), 23 | "mc_start_lr": kwargs.get("start_lr"), 24 | "mc_ref_lr": kwargs.get("ref_lr"), 25 | "mc_final_lr": kwargs.get("final_lr"), 26 | "mc_ref_wd": kwargs.get("ref_wd"), 27 | "mc_final_wd": kwargs.get("final_wd"), 28 | } 29 | ] 30 | logger.info("Using AdamW") 31 | optimizers += [torch.optim.AdamW(param_groups)] 32 | schedulers += [WarmupCosineLRSchedule(optimizers[-1], T_max=int(num_epochs * iterations_per_epoch))] 33 | wd_schedulers += [CosineWDSchedule(optimizers[-1], T_max=int(num_epochs * iterations_per_epoch))] 34 | scalers += [torch.cuda.amp.GradScaler() if use_bfloat16 else None] 35 | return optimizers, scalers, schedulers, wd_schedulers 36 | 37 | 38 | class WarmupCosineLRSchedule(object): 39 | 40 | def __init__(self, optimizer, T_max, last_epoch=-1): 41 | self.optimizer = optimizer 42 | self.T_max = T_max 43 | self._step = 0.0 44 | 45 | def step(self): 46 | self._step += 1 47 | for group in self.optimizer.param_groups: 48 | ref_lr = group.get("mc_ref_lr") 49 | final_lr = group.get("mc_final_lr") 50 | start_lr = group.get("mc_start_lr") 51 | warmup_steps = group.get("mc_warmup_steps") 52 | T_max = self.T_max - warmup_steps 53 | if self._step < warmup_steps: 54 | progress = float(self._step) / float(max(1, warmup_steps)) 55 | new_lr = start_lr + progress * (ref_lr - start_lr) 56 | else: 57 | # -- progress after warmup 58 | progress = float(self._step - warmup_steps) / float(max(1, T_max)) 59 | new_lr = max( 60 | final_lr, 61 | final_lr + (ref_lr - final_lr) * 0.5 * (1.0 + 
math.cos(math.pi * progress)), 62 | ) 63 | group["lr"] = new_lr 64 | 65 | 66 | class CosineWDSchedule(object): 67 | 68 | def __init__(self, optimizer, T_max): 69 | self.optimizer = optimizer 70 | self.T_max = T_max 71 | self._step = 0.0 72 | 73 | def step(self): 74 | self._step += 1 75 | progress = self._step / self.T_max 76 | 77 | for group in self.optimizer.param_groups: 78 | ref_wd = group.get("mc_ref_wd") 79 | final_wd = group.get("mc_final_wd") 80 | new_wd = final_wd + (ref_wd - final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress)) 81 | if final_wd <= ref_wd: 82 | new_wd = max(final_wd, new_wd) 83 | else: 84 | new_wd = min(final_wd, new_wd) 85 | group["weight_decay"] = new_wd 86 | -------------------------------------------------------------------------------- /evals/hub/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/evals/hub/__init__.py -------------------------------------------------------------------------------- /evals/hub/preprocessor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | def _make_transforms(crop_size=256): 8 | from ..video_classification_frozen.utils import make_transforms 9 | 10 | return make_transforms(crop_size=crop_size, training=False) 11 | 12 | 13 | def vjepa2_preprocessor(*, pretrained: bool = True, **kwargs): 14 | crop_size = kwargs.get("crop_size", 256) 15 | return _make_transforms(crop_size=crop_size) 16 | -------------------------------------------------------------------------------- /evals/image_classification_frozen/modelcustom/vit_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | ------------------------------------------------------------------------------ 7 | 8 | modelcustom API requirements: 9 | 10 | API requirements for Encoder module: 11 | 1) Needs to be a pytorch module with 'forward()' function protocol: 12 | :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width]) 13 | :returns: (Tensor) Representations of video clip (shape=[batch_size x num_encoder_tokens x feature_dim]) 14 | 2) Needs to have a public attribute called 'embed_dim' (int) describing its 15 | output feature dimension. 16 | 17 | API requirements for Predictor module: 18 | 1) Needs to be a pytorch module with 'forward()' function protocol: 19 | :param x: (Tensor) Video clip tokens (shape=[batch_size x num_encoder_tokens x feature_dim]) 20 | :param anticipation_time: (Tensor) Seconds into the future to predict for each sample in batch 21 | (shape=[batch_size]) 22 | :returns: (Tensor) Representations of future frames (shape=[batch_size x num_output_tokens x feature_dim]) 23 | 2) Needs to have a public attribute called 'embed_dim' (int) describing its 24 | output feature dimension. 
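    Rough shape sketch (numbers are illustrative, assuming the ViT encoders defined in
    src/models/vision_transformer.py with patch_size=16 and tubelet_size=2): a clip of
    shape [B, 3, num_frames, res, res] yields (num_frames / 2) * (res / 16) ** 2 encoder
    tokens, e.g. num_frames=2 at res=256 gives 1 * 16 * 16 = 256 tokens of size embed_dim.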
25 | """ 26 | 27 | import logging 28 | 29 | import torch 30 | 31 | import src.models.vision_transformer as vit 32 | 33 | logging.basicConfig() 34 | logger = logging.getLogger() 35 | logger.setLevel(logging.INFO) 36 | 37 | 38 | def init_module( 39 | resolution: int, 40 | checkpoint: str, 41 | # -- 42 | model_kwargs: dict, 43 | wrapper_kwargs: dict, 44 | **kwargs, 45 | ): 46 | logger.info(f"Loading pretrained model from {checkpoint=}") 47 | checkpoint = torch.load(checkpoint, map_location="cpu") 48 | 49 | img_as_video_nframes = wrapper_kwargs.get("img_as_video_nframes") 50 | # -- 51 | enc_kwargs = model_kwargs["encoder"] 52 | enc_ckp_key = enc_kwargs.get("checkpoint_key") 53 | enc_model_name = enc_kwargs.get("model_name") 54 | 55 | model = vit.__dict__[enc_model_name]( 56 | input_size=resolution, 57 | num_frames=img_as_video_nframes, 58 | **enc_kwargs, 59 | ) 60 | 61 | def forward_prehook(module, input): 62 | input = input[0] # [B, C, H, W] 63 | input = input.unsqueeze(2).repeat(1, 1, img_as_video_nframes, 1, 1) 64 | return input 65 | 66 | model.register_forward_pre_hook(forward_prehook) 67 | 68 | pretrained_dict = checkpoint[enc_ckp_key] 69 | # -- 70 | pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()} 71 | pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()} 72 | for k, v in model.state_dict().items(): 73 | if k not in pretrained_dict: 74 | logger.info(f'key "{k}" could not be found in loaded state dict') 75 | elif pretrained_dict[k].shape != v.shape: 76 | logger.info(f'key "{k}" is of different shape in model and loaded state dict') 77 | pretrained_dict[k] = v 78 | msg = model.load_state_dict(pretrained_dict, strict=False) 79 | logger.info(f"loaded pretrained model with msg: {msg}") 80 | print(model) 81 | 82 | del checkpoint 83 | return model 84 | -------------------------------------------------------------------------------- /evals/image_classification_frozen/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import logging 8 | 9 | logging.basicConfig() 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 | 14 | def init_module( 15 | module_name, 16 | device, 17 | resolution, 18 | checkpoint, 19 | model_kwargs, 20 | wrapper_kwargs, 21 | ): 22 | """ 23 | Build (frozen) model and initialize from pretrained checkpoint 24 | 25 | API requirements for Encoder module: 26 | 1) Needs to be a pytorch module with 'forward()' function protocol: 27 | :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width]) 28 | :returns: (Tensor) Representations of video clip (shape=[batch_size x num_encoder_tokens x feature_dim]) 29 | """ 30 | model = ( 31 | importlib.import_module(f"{module_name}") 32 | .init_module( 33 | resolution=resolution, 34 | checkpoint=checkpoint, 35 | model_kwargs=model_kwargs, 36 | wrapper_kwargs=wrapper_kwargs, 37 | ) 38 | .to(device) 39 | ) 40 | model.eval() 41 | for p in model.parameters(): 42 | p.requires_grad = False 43 | print(model) 44 | return model 45 | -------------------------------------------------------------------------------- /evals/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import multiprocessing as mp 8 | import os 9 | import pprint 10 | 11 | import yaml 12 | 13 | from evals.scaffold import main as eval_main 14 | from src.utils.distributed import init_distributed 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--val_only", action="store_true", help="only run eval", default=False) 18 | parser.add_argument("--fname", type=str, help="name of config file to load", default="configs.yaml") 19 | parser.add_argument( 20 | "--devices", 21 | type=str, 22 | nargs="+", 23 | default=["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"], 24 | help="which devices to use on local machine", 25 | ) 26 | parser.add_argument( 27 | "--debugmode", 28 | type=bool, 29 | default=False, 30 | help="Setting this to true will not spin up new processes. " 31 | "The main code runs the main process, which makes it easier to debug with checkpointing.", 32 | ) 33 | parser.add_argument( 34 | "--folder", 35 | type=str, 36 | help="location to save logs", 37 | default="", 38 | ) 39 | parser.add_argument("--override_config_folder", action="store_true") 40 | parser.add_argument("--checkpoint", type=str, help="location of pretrained ckpt") 41 | parser.add_argument("--model_name", type=str, help="Model name") 42 | parser.add_argument("--batch_size", type=int) 43 | parser.add_argument("--use_fsdp", action="store_true") 44 | 45 | 46 | def process_main(args, rank, fname, world_size, devices): 47 | import logging 48 | import os 49 | 50 | os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1]) 51 | 52 | logging.basicConfig() 53 | logger = logging.getLogger() 54 | if rank == 0: 55 | logger.setLevel(logging.INFO) 56 | else: 57 | logger.setLevel(logging.ERROR) 58 | 59 | logger.info(f"called-params {fname}") 60 | 61 | # Load config 62 | params = None 63 | with open(fname, "r") as y_file: 64 | params = yaml.load(y_file, Loader=yaml.FullLoader) 65 | if args.val_only: 66 | params["val_only"] = True 67 | 68 | if args.checkpoint: 69 | params["model_kwargs"]["checkpoint"] = args.checkpoint 70 | 71 | if args.model_name: 72 | params["model_kwargs"]["pretrain_kwargs"]["encoder"]["model_name"] = args.model_name 73 | 74 | if args.batch_size: 75 | params["experiment"]["optimization"]["batch_size"] = args.batch_size 76 | 77 | if args.override_config_folder: 78 | params["folder"] = args.folder 79 | params["use_fsdp"] = args.use_fsdp 80 | logger.info("loaded params...") 81 | 82 | if rank == 0: 83 | pprint.PrettyPrinter(indent=4).pprint(params) 84 | 85 | # Init distributed (access to comm between GPUS on same machine) 86 | world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) 87 | logger.info(f"Running... 
(rank: {rank}/{world_size})") 88 | 89 | # Launch the eval with loaded config 90 | eval_main(params["eval_name"], args_eval=params) 91 | 92 | 93 | if __name__ == "__main__": 94 | args = parser.parse_args() 95 | if args.debugmode: 96 | # FSDP debugging (use torchrun) 97 | if args.use_fsdp: 98 | process_main( 99 | args=args, 100 | rank=int(os.environ["RANK"]), 101 | fname=args.fname, 102 | world_size=int(os.environ["WORLD_SIZE"]), 103 | devices=args.devices, 104 | ) 105 | # Single-GPU debugging 106 | else: 107 | process_main(args=args, rank=0, fname=args.fname, world_size=1, devices=["cuda:0"]) 108 | else: 109 | num_gpus = len(args.devices) 110 | mp.set_start_method("spawn") 111 | for rank in range(num_gpus): 112 | mp.Process(target=process_main, args=(args, rank, args.fname, num_gpus, args.devices)).start() 113 | -------------------------------------------------------------------------------- /evals/scaffold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | 8 | from src.utils.logging import get_logger 9 | 10 | logger = get_logger("Eval runner scaffold") 11 | 12 | 13 | def main(eval_name, args_eval, resume_preempt=False): 14 | logger.info(f"Running evaluation: {eval_name}") 15 | if eval_name.startswith("app."): 16 | import_path = f"{eval_name}.eval" 17 | else: 18 | import_path = f"evals.{eval_name}.eval" 19 | return importlib.import_module(import_path).main(args_eval=args_eval, resume_preempt=resume_preempt) 20 | -------------------------------------------------------------------------------- /evals/video_classification_frozen/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import importlib 7 | import logging 8 | 9 | logging.basicConfig() 10 | logger = logging.getLogger() 11 | logger.setLevel(logging.INFO) 12 | 13 | 14 | def init_module( 15 | module_name, 16 | device, 17 | frames_per_clip, 18 | resolution, 19 | checkpoint, 20 | model_kwargs, 21 | wrapper_kwargs, 22 | ): 23 | """ 24 | Build (frozen) model and initialize from pretrained checkpoint 25 | 26 | API requirements for Encoder module: 27 | 1) Needs to be a pytorch module with 'forward()' function protocol: 28 | :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width]) 29 | :returns: (Tensor) Representations of video clip (shape=[batch_size x num_encoder_tokens x feature_dim]) 30 | """ 31 | model = ( 32 | importlib.import_module(f"{module_name}") 33 | .init_module( 34 | frames_per_clip=frames_per_clip, 35 | resolution=resolution, 36 | checkpoint=checkpoint, 37 | model_kwargs=model_kwargs, 38 | wrapper_kwargs=wrapper_kwargs, 39 | ) 40 | .to(device) 41 | ) 42 | model.eval() 43 | for p in model.parameters(): 44 | p.requires_grad = False 45 | print(model) 46 | return model 47 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
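#
# Loading sketch via torch.hub (assumes network access to the released checkpoints;
# the entrypoint names are the ones re-exported below):
#
#   import torch
#   encoder, predictor = torch.hub.load("facebookresearch/vjepa2", "vjepa2_vit_large")
#   preprocess = torch.hub.load("facebookresearch/vjepa2", "vjepa2_preprocessor")
#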
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from evals.hub.preprocessor import vjepa2_preprocessor 7 | from src.hub.backbones import ( 8 | vjepa2_ac_vit_giant, 9 | vjepa2_vit_giant, 10 | vjepa2_vit_giant_384, 11 | vjepa2_vit_huge, 12 | vjepa2_vit_large, 13 | ) 14 | 15 | dependencies = ["torch", "timm", "einops"] 16 | -------------------------------------------------------------------------------- /notebooks/franka_example_traj.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/notebooks/franka_example_traj.npz -------------------------------------------------------------------------------- /notebooks/utils/world_model_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | from .mpc_utils import cem, compute_new_pose 10 | 11 | 12 | class WorldModel(object): 13 | 14 | def __init__( 15 | self, 16 | encoder, 17 | predictor, 18 | tokens_per_frame, 19 | transform, 20 | mpc_args={ 21 | "rollout": 2, 22 | "samples": 400, 23 | "topk": 10, 24 | "cem_steps": 10, 25 | "momentum_mean": 0.15, 26 | "momentum_std": 0.15, 27 | "maxnorm": 0.05, 28 | "verbose": True, 29 | }, 30 | normalize_reps=True, 31 | device="cuda:0", 32 | ): 33 | super().__init__() 34 | self.encoder = encoder 35 | self.predictor = predictor 36 | self.normalize_reps = normalize_reps 37 | self.transform = transform 38 | self.tokens_per_frame = tokens_per_frame 39 | self.device = device 40 | self.mpc_args = mpc_args 41 | 42 | def encode(self, image): 43 | clip = np.expand_dims(image, axis=0) 44 | clip = self.transform(clip)[None, :] 45 | B, C, T, H, W = clip.size() 46 | clip = clip.permute(0, 2, 1, 3, 4).flatten(0, 1).unsqueeze(2).repeat(1, 1, 2, 1, 1) 47 | clip = clip.to(self.device, non_blocking=True) 48 | h = self.encoder(clip) 49 | h = h.view(B, T, -1, h.size(-1)).flatten(1, 2) 50 | if self.normalize_reps: 51 | h = F.layer_norm(h, (h.size(-1),)) 52 | return h 53 | 54 | def infer_next_action(self, rep, pose, goal_rep, close_gripper=None): 55 | 56 | def step_predictor(reps, actions, poses): 57 | B, T, N_T, D = reps.size() 58 | reps = reps.flatten(1, 2) 59 | next_rep = self.predictor(reps, actions, poses)[:, -self.tokens_per_frame :] 60 | if self.normalize_reps: 61 | next_rep = F.layer_norm(next_rep, (next_rep.size(-1),)) 62 | next_rep = next_rep.view(B, 1, N_T, D) 63 | next_pose = compute_new_pose(poses[:, -1:], actions[:, -1:]) 64 | return next_rep, next_pose 65 | 66 | mpc_action = cem( 67 | context_frame=rep, 68 | context_pose=pose, 69 | goal_frame=goal_rep, 70 | world_model=step_predictor, 71 | close_gripper=close_gripper, 72 | **self.mpc_args, 73 | )[0] 74 | 75 | return mpc_action 76 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile="black" 3 | line_length=119 4 | 5 | [tool.black] 6 | line-length = 119 7 | -------------------------------------------------------------------------------- /requirements-test.txt: 
-------------------------------------------------------------------------------- 1 | # Tools for static checking. 2 | black == 24.4.2 3 | flake8 == 7.0.0 4 | isort == 5.13.2 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2 2 | torchvision 3 | tensorboard 4 | wandb 5 | iopath 6 | pyyaml 7 | numpy 8 | opencv-python 9 | submitit 10 | braceexpand 11 | webdataset 12 | timm 13 | transformers 14 | peft 15 | decord 16 | pandas 17 | einops 18 | beartype 19 | psutil 20 | h5py 21 | fire 22 | python-box 23 | scikit-image 24 | ftfy 25 | jupyter 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup 7 | 8 | NAME = "vjepa2" 9 | VERSION = "0.0.1" 10 | DESCRIPTION = "PyTorch code and models for V-JEPA 2." 11 | URL = "https://github.com/facebookresearch/vjepa2" 12 | 13 | 14 | def get_requirements(): 15 | with open("./requirements.txt") as reqsf: 16 | reqs = reqsf.readlines() 17 | return reqs 18 | 19 | 20 | if __name__ == "__main__": 21 | setup( 22 | name=NAME, 23 | version=VERSION, 24 | description=DESCRIPTION, 25 | url=URL, 26 | python_requires=">=3.11", 27 | install_requires=get_requirements(), 28 | ) 29 | -------------------------------------------------------------------------------- /src/datasets/data_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
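#
# Usage sketch for init_data below (values and paths are illustrative; the real call
# sites assemble these arguments from the YAML configs under configs/):
#
#   data_loader, sampler = init_data(
#       data="VideoDataset",
#       root_path=["/path/to/annotations.csv"],  # hypothetical annotation list
#       batch_size=8,
#       clip_len=16,
#       transform=transform,                     # hypothetical transform callable
#       world_size=1,
#       rank=0,
#   )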
5 | 6 | from logging import getLogger 7 | 8 | _GLOBAL_SEED = 0 9 | logger = getLogger() 10 | 11 | 12 | def init_data( 13 | batch_size, 14 | transform=None, 15 | shared_transform=None, 16 | data="ImageNet", 17 | collator=None, 18 | pin_mem=True, 19 | num_workers=8, 20 | world_size=1, 21 | rank=0, 22 | root_path=None, 23 | image_folder=None, 24 | training=True, 25 | drop_last=True, 26 | subset_file=None, 27 | clip_len=None, 28 | dataset_fpcs=None, 29 | frame_sample_rate=None, 30 | duration=None, 31 | fps=None, 32 | num_clips=1, 33 | random_clip_sampling=True, 34 | allow_clip_overlap=False, 35 | filter_short_videos=False, 36 | filter_long_videos=int(1e9), 37 | datasets_weights=None, 38 | persistent_workers=False, 39 | deterministic=True, 40 | log_dir=None, 41 | ): 42 | if data.lower() == "imagenet": 43 | from src.datasets.imagenet1k import make_imagenet1k 44 | 45 | dataset, data_loader, dist_sampler = make_imagenet1k( 46 | transform=transform, 47 | batch_size=batch_size, 48 | collator=collator, 49 | pin_mem=pin_mem, 50 | training=training, 51 | num_workers=num_workers, 52 | world_size=world_size, 53 | rank=rank, 54 | root_path=root_path, 55 | image_folder=image_folder, 56 | persistent_workers=persistent_workers, 57 | drop_last=drop_last, 58 | subset_file=subset_file, 59 | ) 60 | 61 | elif data.lower() == "videodataset": 62 | from src.datasets.video_dataset import make_videodataset 63 | 64 | dataset, data_loader, dist_sampler = make_videodataset( 65 | data_paths=root_path, 66 | batch_size=batch_size, 67 | frames_per_clip=clip_len, 68 | dataset_fpcs=dataset_fpcs, 69 | frame_step=frame_sample_rate, 70 | duration=duration, 71 | fps=fps, 72 | num_clips=num_clips, 73 | random_clip_sampling=random_clip_sampling, 74 | allow_clip_overlap=allow_clip_overlap, 75 | filter_short_videos=filter_short_videos, 76 | filter_long_videos=filter_long_videos, 77 | shared_transform=shared_transform, 78 | transform=transform, 79 | datasets_weights=datasets_weights, 80 | collator=collator, 81 | num_workers=num_workers, 82 | pin_mem=pin_mem, 83 | persistent_workers=persistent_workers, 84 | world_size=world_size, 85 | rank=rank, 86 | deterministic=deterministic, 87 | log_dir=log_dir, 88 | ) 89 | 90 | return (data_loader, dist_sampler) 91 | -------------------------------------------------------------------------------- /src/datasets/imagenet1k.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
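#
# Expected on-disk layout (paths illustrative), following the torchvision ImageFolder
# convention used below:
#   <root>/<image_folder>/train/<class_name>/<image>.JPEG
#   <root>/<image_folder>/val/<class_name>/<image>.JPEG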
5 | 6 | import os 7 | import subprocess 8 | import time 9 | from logging import getLogger 10 | 11 | import numpy as np 12 | import torch 13 | import torchvision 14 | 15 | _GLOBAL_SEED = 0 16 | logger = getLogger() 17 | 18 | 19 | class ImageNet(torchvision.datasets.ImageFolder): 20 | 21 | def __init__( 22 | self, 23 | root, 24 | image_folder="imagenet_full_size/061417/", 25 | tar_file="imagenet_full_size-061417.tar.gz", 26 | transform=None, 27 | train=True, 28 | job_id=None, 29 | local_rank=None, 30 | index_targets=False, 31 | ): 32 | """ 33 | ImageNet 34 | 35 | Dataset wrapper 36 | 37 | :param root: root network directory for ImageNet data 38 | :param image_folder: path to images inside root network directory 39 | :param tar_file: zipped image_folder inside root network directory 40 | :param train: whether to load train data (or validation) 41 | :param job_id: scheduler job-id used to create dir on local machine 42 | :param index_targets: whether to index the id of each labeled image 43 | """ 44 | 45 | suffix = "train/" if train else "val/" 46 | data_path = os.path.join(root, image_folder, suffix) 47 | logger.info(f"data-path {data_path}") 48 | 49 | super(ImageNet, self).__init__(root=data_path, transform=transform) 50 | logger.info("Initialized ImageNet") 51 | 52 | if index_targets: 53 | self.targets = [] 54 | for sample in self.samples: 55 | self.targets.append(sample[1]) 56 | self.targets = np.array(self.targets) 57 | self.samples = np.array(self.samples) 58 | 59 | mint = None 60 | self.target_indices = [] 61 | for t in range(len(self.classes)): 62 | indices = np.squeeze(np.argwhere(self.targets == t)).tolist() 63 | self.target_indices.append(indices) 64 | mint = len(indices) if mint is None else min(mint, len(indices)) 65 | logger.debug(f"num-labeled target {t} {len(indices)}") 66 | logger.info(f"min. 
labeled indices {mint}") 67 | 68 | 69 | class ImageNetSubset(object): 70 | 71 | def __init__(self, dataset, subset_file): 72 | """ 73 | ImageNetSubset 74 | 75 | :param dataset: ImageNet dataset object 76 | :param subset_file: '.txt' file containing IDs of IN1K images to keep 77 | """ 78 | self.dataset = dataset 79 | self.subset_file = subset_file 80 | self.filter_dataset_(subset_file) 81 | 82 | def filter_dataset_(self, subset_file): 83 | """Filter self.dataset to a subset""" 84 | root = self.dataset.root 85 | class_to_idx = self.dataset.class_to_idx 86 | # -- update samples to subset of IN1k targets/samples 87 | new_samples = [] 88 | logger.info(f"Using {subset_file}") 89 | with open(subset_file, "r") as rfile: 90 | for line in rfile: 91 | class_name = line.split("_")[0] 92 | target = class_to_idx[class_name] 93 | img = line.split("\n")[0] 94 | new_samples.append((os.path.join(root, class_name, img), target)) 95 | self.samples = new_samples 96 | 97 | @property 98 | def classes(self): 99 | return self.dataset.classes 100 | 101 | def __len__(self): 102 | return len(self.samples) 103 | 104 | def __getitem__(self, index): 105 | path, target = self.samples[index] 106 | img = self.dataset.loader(path) 107 | if self.dataset.transform is not None: 108 | img = self.dataset.transform(img) 109 | if self.dataset.target_transform is not None: 110 | target = self.dataset.target_transform(target) 111 | return img, target 112 | 113 | 114 | def make_imagenet1k( 115 | transform, 116 | batch_size, 117 | collator=None, 118 | pin_mem=True, 119 | num_workers=8, 120 | world_size=1, 121 | rank=0, 122 | root_path=None, 123 | image_folder=None, 124 | training=True, 125 | drop_last=True, 126 | persistent_workers=False, 127 | subset_file=None, 128 | ): 129 | dataset = ImageNet( 130 | root=root_path, 131 | image_folder=image_folder, 132 | transform=transform, 133 | train=training, 134 | index_targets=False, 135 | ) 136 | if subset_file is not None: 137 | dataset = ImageNetSubset(dataset, subset_file) 138 | logger.info("ImageNet dataset created") 139 | dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset=dataset, num_replicas=world_size, rank=rank) 140 | data_loader = torch.utils.data.DataLoader( 141 | dataset, 142 | collate_fn=collator, 143 | sampler=dist_sampler, 144 | batch_size=batch_size, 145 | drop_last=drop_last, 146 | pin_memory=pin_mem, 147 | num_workers=num_workers, 148 | persistent_workers=persistent_workers, 149 | ) 150 | logger.info("ImageNet unsupervised data loader created") 151 | 152 | return dataset, data_loader, dist_sampler 153 | -------------------------------------------------------------------------------- /src/datasets/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from src.utils.cluster import dataset_paths 7 | from src.utils.logging import get_logger 8 | 9 | logger = get_logger("Datasets utils") 10 | 11 | 12 | def get_dataset_paths(datasets: list[str]): 13 | paths = [] 14 | for d in datasets: 15 | try: 16 | path = dataset_paths().get(d) 17 | except Exception: 18 | raise Exception(f"Unknown dataset: {d}") 19 | paths.append(path) 20 | logger.info(f"Datapaths {paths}") 21 | return paths 22 | -------------------------------------------------------------------------------- /src/datasets/utils/video/functional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numbers 7 | 8 | import cv2 9 | import numpy as np 10 | import PIL 11 | import torch 12 | from torchvision.transforms import functional as tvf 13 | 14 | 15 | def _is_tensor_clip(clip): 16 | return torch.is_tensor(clip) and clip.ndimension() == 4 17 | 18 | 19 | def crop_clip(clip, min_h, min_w, h, w): 20 | if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor): 21 | if clip[0].shape[-1] == 3: 22 | cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip] 23 | else: 24 | assert clip[0].shape[0] == 3 25 | cropped = [img[:, min_h : min_h + h, min_w : min_w + w] for img in clip] 26 | 27 | elif isinstance(clip[0], PIL.Image.Image): 28 | cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip] 29 | 30 | else: 31 | raise TypeError( 32 | "Expected numpy.ndarray or PIL.Image or torch.Tensor):" + "but got list of {0}".format(type(clip[0])) 33 | ) 34 | return cropped 35 | 36 | 37 | def resize_clip(clip, size, interpolation="bilinear"): 38 | if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor): 39 | if isinstance(size, numbers.Number): 40 | if clip[0].shape[-1] == 3: 41 | im_h, im_w, im_c = clip[0].shape 42 | else: 43 | assert clip[0].shape[0] == 3 44 | im_c, im_h, im_w = clip[0].shape 45 | # Min spatial dim already matches minimal size 46 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): 47 | return clip 48 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 49 | size = (new_w, new_h) 50 | else: 51 | size = size[0], size[1] 52 | 53 | if isinstance(clip[0], np.ndarray): 54 | if interpolation == "bilinear": 55 | np_inter = cv2.INTER_LINEAR 56 | else: 57 | np_inter = cv2.INTER_NEAREST 58 | scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip] 59 | else: # isinstance(clip[0], torch.Tensor) 60 | if interpolation == "bilinear": 61 | np_inter = tvf.InterpolationMode.BILINEAR 62 | else: 63 | np_inter = tvf.InterpolationMode.NEAREST 64 | size = (size[1], size[0]) # torchvision transformers expect the size in (h, w) order. 
65 | scaled = [tvf.resize(img, size, interpolation=np_inter) for img in clip] 66 | elif isinstance(clip[0], PIL.Image.Image): 67 | if isinstance(size, numbers.Number): 68 | im_w, im_h = clip[0].size 69 | # Min spatial dim already matches minimal size 70 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): 71 | return clip 72 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 73 | size = (new_w, new_h) 74 | else: 75 | size = size[1], size[0] 76 | if interpolation == "bilinear": 77 | pil_inter = PIL.Image.BILINEAR 78 | else: 79 | pil_inter = PIL.Image.NEAREST 80 | scaled = [img.resize(size, pil_inter) for img in clip] 81 | else: 82 | raise TypeError( 83 | "Expected numpy.ndarray or PIL.Image or torch.Tensor" + "but got list of {0}".format(type(clip[0])) 84 | ) 85 | return scaled 86 | 87 | 88 | def get_resize_sizes(im_h, im_w, size): 89 | if im_w < im_h: 90 | ow = size 91 | oh = int(size * im_h / im_w) 92 | else: 93 | oh = size 94 | ow = int(size * im_w / im_h) 95 | return oh, ow 96 | 97 | 98 | def normalize(clip, mean, std, inplace=False): 99 | if not _is_tensor_clip(clip): 100 | raise TypeError("tensor is not a torch clip.") 101 | 102 | if not inplace: 103 | clip = clip.clone() 104 | 105 | dtype = clip.dtype 106 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 107 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 108 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 109 | 110 | return clip 111 | -------------------------------------------------------------------------------- /src/datasets/utils/video/transforms_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
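#
# Usage sketch (argument values are illustrative; training apps build this from their configs):
#
#   transform = make_transforms(crop_size=256, auto_augment=True, reprob=0.25)
#   clip = transform(buffer)  # buffer: [T, H, W, C] uint8 frames -> clip: [C, T, H, W] float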
5 | 6 | from typing import Optional 7 | 8 | import torch 9 | import torchvision.transforms as transforms 10 | 11 | import src.datasets.utils.video.transforms as video_transforms 12 | from src.datasets.utils.video.randerase import RandomErasing 13 | 14 | 15 | def make_transforms( 16 | random_horizontal_flip=True, 17 | random_resize_aspect_ratio=(3 / 4, 4 / 3), 18 | random_resize_scale=(0.3, 1.0), 19 | reprob=0.0, 20 | auto_augment=False, 21 | motion_shift=False, 22 | crop_size=224, 23 | normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 24 | pad_frame_count: Optional[int] = None, 25 | pad_frame_method: str = "circulant", 26 | ): 27 | _frames_augmentation = VideoTransform( 28 | random_horizontal_flip=random_horizontal_flip, 29 | random_resize_aspect_ratio=random_resize_aspect_ratio, 30 | random_resize_scale=random_resize_scale, 31 | reprob=reprob, 32 | auto_augment=auto_augment, 33 | motion_shift=motion_shift, 34 | crop_size=crop_size, 35 | normalize=normalize, 36 | pad_frame_count=pad_frame_count, 37 | pad_frame_method=pad_frame_method, 38 | ) 39 | return _frames_augmentation 40 | 41 | 42 | class VideoTransform(object): 43 | 44 | def __init__( 45 | self, 46 | random_horizontal_flip=True, 47 | random_resize_aspect_ratio=(3 / 4, 4 / 3), 48 | random_resize_scale=(0.3, 1.0), 49 | reprob=0.0, 50 | auto_augment=False, 51 | motion_shift=False, 52 | crop_size=224, 53 | normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 54 | pad_frame_count: Optional[int] = None, 55 | pad_frame_method: str = "circulant", 56 | ): 57 | self.random_horizontal_flip = random_horizontal_flip 58 | self.random_resize_aspect_ratio = random_resize_aspect_ratio 59 | self.random_resize_scale = random_resize_scale 60 | self.auto_augment = auto_augment 61 | self.motion_shift = motion_shift 62 | self.crop_size = crop_size 63 | self.mean = torch.tensor(normalize[0], dtype=torch.float32) 64 | self.std = torch.tensor(normalize[1], dtype=torch.float32) 65 | self.pad_frame_count = pad_frame_count 66 | self.pad_frame_method = pad_frame_method 67 | 68 | if not self.auto_augment: 69 | # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255. 70 | self.mean *= 255.0 71 | self.std *= 255.0 72 | 73 | self.autoaug_transform = video_transforms.create_random_augment( 74 | input_size=(crop_size, crop_size), 75 | auto_augment="rand-m7-n4-mstd0.5-inc1", 76 | interpolation="bicubic", 77 | ) 78 | 79 | self.spatial_transform = ( 80 | video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop 81 | ) 82 | 83 | self.reprob = reprob 84 | self.erase_transform = RandomErasing( 85 | reprob, 86 | mode="pixel", 87 | max_count=1, 88 | num_splits=1, 89 | device="cpu", 90 | ) 91 | 92 | def __call__(self, buffer): 93 | 94 | if self.auto_augment: 95 | buffer = [transforms.ToPILImage()(frame) for frame in buffer] 96 | buffer = self.autoaug_transform(buffer) 97 | buffer = [transforms.ToTensor()(img) for img in buffer] 98 | buffer = torch.stack(buffer) # T C H W 99 | buffer = buffer.permute(0, 2, 3, 1) # T H W C 100 | elif torch.is_tensor(buffer): 101 | # TODO: ensure input is always a tensor? 
102 | buffer = buffer.to(torch.float32) 103 | else: 104 | buffer = torch.tensor(buffer, dtype=torch.float32) 105 | 106 | buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W 107 | 108 | buffer = self.spatial_transform( 109 | images=buffer, 110 | target_height=self.crop_size, 111 | target_width=self.crop_size, 112 | scale=self.random_resize_scale, 113 | ratio=self.random_resize_aspect_ratio, 114 | ) 115 | if self.random_horizontal_flip: 116 | buffer, _ = video_transforms.horizontal_flip(0.5, buffer) 117 | 118 | buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) 119 | if self.reprob > 0: 120 | buffer = buffer.permute(1, 0, 2, 3) 121 | buffer = self.erase_transform(buffer) 122 | buffer = buffer.permute(1, 0, 2, 3) 123 | 124 | if self.pad_frame_count is not None: 125 | buffer = video_transforms.frame_pad(buffer, self.pad_frame_count, self.pad_frame_method) 126 | 127 | return buffer 128 | 129 | 130 | def tensor_normalize(tensor, mean, std): 131 | """ 132 | Normalize a given tensor by subtracting the mean and dividing the std. 133 | Args: 134 | tensor (tensor): tensor to normalize. 135 | mean (tensor or list): mean value to subtract. 136 | std (tensor or list): std to divide. 137 | """ 138 | if tensor.dtype == torch.uint8: 139 | tensor = tensor.float() 140 | tensor = tensor / 255.0 141 | if isinstance(mean, list): 142 | mean = torch.tensor(mean) 143 | if isinstance(std, list): 144 | std = torch.tensor(std) 145 | tensor = tensor - mean 146 | tensor = tensor / std 147 | return tensor 148 | 149 | 150 | def _tensor_normalize_inplace(tensor, mean, std): 151 | """ 152 | Normalize a given tensor by subtracting the mean and dividing the std. 153 | Args: 154 | tensor (tensor): tensor to normalize (with dimensions C, T, H, W). 155 | mean (tensor): mean value to subtract (in 0 to 255 floats). 156 | std (tensor): std to divide (in 0 to 255 floats). 157 | """ 158 | if tensor.dtype == torch.uint8: 159 | tensor = tensor.float() 160 | 161 | C, T, H, W = tensor.shape 162 | tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension 163 | tensor.sub_(mean).div_(std) 164 | tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front 165 | return tensor 166 | -------------------------------------------------------------------------------- /src/datasets/utils/video/volume_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def convert_img(img): 12 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" 13 | if len(img.shape) == 3: 14 | img = img.transpose(2, 0, 1) 15 | if len(img.shape) == 2: 16 | img = np.expand_dims(img, 0) 17 | return img 18 | 19 | 20 | class ClipToTensor(object): 21 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 22 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 23 | """ 24 | 25 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 26 | self.channel_nb = channel_nb 27 | self.div_255 = div_255 28 | self.numpy = numpy 29 | 30 | def __call__(self, clip): 31 | """ 32 | Args: clip (list of numpy.ndarray): clip (list of images) 33 | to be converted to tensor. 
34 | """ 35 | # Retrieve shape 36 | if isinstance(clip[0], np.ndarray): 37 | h, w, ch = clip[0].shape 38 | assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) 39 | elif isinstance(clip[0], Image.Image): 40 | w, h = clip[0].size 41 | elif isinstance(clip[0], torch.Tensor): 42 | tensor_clip = torch.stack(clip) 43 | # Converting (T, C, H, W) -> (C, T, H, W) to match what `convert_img` followed by 44 | # `np_clip[:, img_idx, :, :] = img` does for other data types. 45 | tensor_clip = tensor_clip.permute(1, 0, 2, 3) 46 | if not isinstance(tensor_clip, torch.FloatTensor): 47 | tensor_clip = tensor_clip.float() 48 | if self.div_255: 49 | tensor_clip = torch.div(tensor_clip, 255) 50 | return tensor_clip 51 | else: 52 | raise TypeError( 53 | "Expected numpy.ndarray or PIL.Image or torch.Tensor\ 54 | but got list of {0}".format( 55 | type(clip[0]) 56 | ) 57 | ) 58 | 59 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 60 | 61 | # Convert 62 | for img_idx, img in enumerate(clip): 63 | if isinstance(img, np.ndarray): 64 | pass 65 | elif isinstance(img, Image.Image): 66 | img = np.array(img, copy=False) 67 | else: 68 | raise TypeError( 69 | "Expected numpy.ndarray or PIL.Image\ 70 | but got list of {0}".format( 71 | type(clip[0]) 72 | ) 73 | ) 74 | img = convert_img(img) 75 | np_clip[:, img_idx, :, :] = img 76 | 77 | if self.numpy: 78 | if self.div_255: 79 | np_clip = np_clip / 255.0 80 | return np_clip 81 | 82 | else: 83 | tensor_clip = torch.from_numpy(np_clip) 84 | 85 | if not isinstance(tensor_clip, torch.FloatTensor): 86 | tensor_clip = tensor_clip.float() 87 | if self.div_255: 88 | tensor_clip = torch.div(tensor_clip, 255) 89 | return tensor_clip 90 | 91 | 92 | # Note this norms data to -1/1 93 | class ClipToTensor_K(object): 94 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 95 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 96 | """ 97 | 98 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 99 | self.channel_nb = channel_nb 100 | self.div_255 = div_255 101 | self.numpy = numpy 102 | 103 | def __call__(self, clip): 104 | """ 105 | Args: clip (list of numpy.ndarray): clip (list of images) 106 | to be converted to tensor. 
107 | """ 108 | # Retrieve shape 109 | if isinstance(clip[0], np.ndarray): 110 | h, w, ch = clip[0].shape 111 | assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) 112 | elif isinstance(clip[0], Image.Image): 113 | w, h = clip[0].size 114 | else: 115 | raise TypeError( 116 | "Expected numpy.ndarray or PIL.Image\ 117 | but got list of {0}".format( 118 | type(clip[0]) 119 | ) 120 | ) 121 | 122 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 123 | 124 | # Convert 125 | for img_idx, img in enumerate(clip): 126 | if isinstance(img, np.ndarray): 127 | pass 128 | elif isinstance(img, Image.Image): 129 | img = np.array(img, copy=False) 130 | else: 131 | raise TypeError( 132 | "Expected numpy.ndarray or PIL.Image\ 133 | but got list of {0}".format( 134 | type(clip[0]) 135 | ) 136 | ) 137 | img = convert_img(img) 138 | np_clip[:, img_idx, :, :] = img 139 | if self.numpy: 140 | if self.div_255: 141 | np_clip = (np_clip - 127.5) / 127.5 142 | return np_clip 143 | 144 | else: 145 | tensor_clip = torch.from_numpy(np_clip) 146 | 147 | if not isinstance(tensor_clip, torch.FloatTensor): 148 | tensor_clip = tensor_clip.float() 149 | if self.div_255: 150 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 151 | return tensor_clip 152 | 153 | 154 | class ToTensor(object): 155 | """Converts numpy array to tensor""" 156 | 157 | def __call__(self, array): 158 | tensor = torch.from_numpy(array) 159 | return tensor 160 | -------------------------------------------------------------------------------- /src/datasets/utils/worker_init_fn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # Copyright The Lightning AI team. 4 | 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # This code originally comes from PyTorch Lighting with some light modificaitons: 18 | # https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/fabric/utilities/seed.py 19 | 20 | 21 | import os 22 | import random 23 | from typing import Optional 24 | 25 | import numpy as np 26 | import torch 27 | 28 | from src.utils.logging import get_logger 29 | 30 | logger = get_logger("worker_init_fn") 31 | 32 | 33 | def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]: 34 | """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG) 35 | algorithm.""" 36 | # Combine base seed, worker id and rank into a unique 64-bit number 37 | combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank 38 | seeds = [] 39 | for _ in range(count): 40 | # x_(n+1) = (a * x_n + c) mod m. With c=1, m=2^64 and a is D. 
Knuth's constant 41 | combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1) 42 | seeds.append(combined_seed) 43 | return seeds 44 | 45 | 46 | def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover 47 | r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with 48 | ``seed_everything(seed, workers=True)``. 49 | 50 | See also the PyTorch documentation on 51 | `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. 52 | 53 | """ 54 | # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 55 | if rank is None: 56 | procid = os.environ.get("SLURM_PROCID") 57 | if procid is None: 58 | logger.warning("SLURM_PROCID is not set, setting rank to 0") 59 | rank = 0 60 | else: 61 | rank = int(procid) 62 | 63 | process_seed = torch.initial_seed() 64 | # back out the base seed so we can use all the bits 65 | base_seed = process_seed - worker_id 66 | logger.debug( 67 | f"Initializing random number generators of process {rank} worker {worker_id} with base seed {base_seed}" 68 | ) 69 | seed_sequence = _generate_seed_sequence(base_seed, worker_id, rank, count=4) 70 | torch.manual_seed(seed_sequence[0]) # torch takes a 64-bit seed 71 | random.seed((seed_sequence[1] << 32) | seed_sequence[2]) # combine two 64-bit seeds 72 | 73 | ss = np.random.SeedSequence([base_seed, worker_id, rank]) 74 | np_rng_seed = ss.generate_state(4) 75 | 76 | np.random.seed(np_rng_seed) 77 | -------------------------------------------------------------------------------- /src/hub/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/src/hub/__init__.py -------------------------------------------------------------------------------- /src/hub/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
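#
# Usage sketch: each public builder below returns an (encoder, predictor) pair and,
# when pretrained=True, downloads the matching checkpoint from VJEPA_BASE_URL, e.g.
#
#   encoder, predictor = vjepa2_vit_large(pretrained=True)
#   encoder, predictor = vjepa2_ac_vit_giant(pretrained=True)  # action-conditioned variant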
5 | 6 | import torch 7 | 8 | VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2" 9 | 10 | # for testing 11 | # VJEPA_BASE_URL = "http://localhost:8300" 12 | 13 | ARCH_NAME_MAP = { 14 | "vit_large": ("vit_large", "vitl"), 15 | "vit_huge": ("vit_huge", "vith"), 16 | "vit_giant": ("vit_giant_xformers", "vitg"), 17 | "vit_ac_giant": ("vit_giant_xformers", "vjepa2-ac-vitg"), 18 | "vit_giant_384": ("vit_giant_xformers", "vitg-384"), 19 | } 20 | 21 | 22 | def _clean_backbone_key(state_dict): 23 | for key, val in state_dict.copy().items(): 24 | _ = state_dict.pop(key) 25 | key = key.replace("module.", "") 26 | key = key.replace("backbone.", "") 27 | state_dict[key] = val 28 | return state_dict 29 | 30 | 31 | def _make_vjepa2_ac_model( 32 | *, 33 | model_name: str = "vit_ac_giant", 34 | img_size=256, 35 | patch_size=16, 36 | tubelet_size=2, 37 | num_frames=64, 38 | pretrained: bool = True, 39 | **kwargs, 40 | ): 41 | from ..models import ac_predictor as vit_ac_predictor 42 | from ..models import vision_transformer as vit_encoder 43 | 44 | vit_encoder_kwargs = dict( 45 | patch_size=patch_size, 46 | img_size=(img_size, img_size), 47 | num_frames=num_frames, 48 | tubelet_size=tubelet_size, 49 | use_sdpa=True, 50 | use_SiLU=False, 51 | wide_SiLU=True, 52 | uniform_power=False, 53 | use_rope=True, 54 | ) 55 | vit_encoder_kwargs.update(**kwargs) 56 | 57 | arch_name = ARCH_NAME_MAP[model_name][0] 58 | encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs) 59 | 60 | vit_predictor_kwargs = dict( 61 | img_size=(img_size, img_size), 62 | patch_size=patch_size, 63 | num_frames=num_frames, 64 | tubelet_size=tubelet_size, 65 | embed_dim=encoder.embed_dim, 66 | ) 67 | vit_predictor_kwargs.update(**kwargs) 68 | 69 | predictor = vit_ac_predictor.__dict__["vit_ac_predictor"](**vit_predictor_kwargs) 70 | 71 | if pretrained: 72 | model_file = ARCH_NAME_MAP[model_name][-1] 73 | url = VJEPA_BASE_URL + f"/{model_file}.pt" 74 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 75 | encoder_state_dict = _clean_backbone_key(state_dict["encoder"]) 76 | encoder.load_state_dict(encoder_state_dict, strict=False) 77 | predictor_state_dict = _clean_backbone_key(state_dict["predictor"]) 78 | predictor.load_state_dict(predictor_state_dict, strict=True) 79 | 80 | return encoder, predictor 81 | 82 | 83 | def _make_vjepa2_model( 84 | *, 85 | model_name: str = "vit_large", 86 | img_size=256, 87 | patch_size=16, 88 | tubelet_size=2, 89 | num_frames=64, 90 | pretrained: bool = True, 91 | **kwargs, 92 | ): 93 | from ..models import predictor as vit_predictor 94 | from ..models import vision_transformer as vit_encoder 95 | 96 | vit_encoder_kwargs = dict( 97 | patch_size=patch_size, 98 | img_size=(img_size, img_size), 99 | num_frames=num_frames, 100 | tubelet_size=tubelet_size, 101 | use_sdpa=True, 102 | use_SiLU=False, 103 | wide_SiLU=True, 104 | uniform_power=False, 105 | use_rope=True, 106 | ) 107 | vit_encoder_kwargs.update(**kwargs) 108 | 109 | arch_name = ARCH_NAME_MAP[model_name][0] 110 | encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs) 111 | 112 | vit_predictor_kwargs = dict( 113 | img_size=(img_size, img_size), 114 | patch_size=patch_size, 115 | use_mask_tokens=True, 116 | embed_dim=encoder.embed_dim, 117 | predictor_embed_dim=384, 118 | num_frames=num_frames, 119 | tubelet_size=tubelet_size, 120 | depth=12, 121 | num_heads=12, 122 | num_mask_tokens=10, 123 | use_rope=True, 124 | uniform_power=False, 125 | use_sdpa=True, 126 | use_silu=False, 127 | wide_silu=True, 128 | ) 129 | 
vit_predictor_kwargs.update(**kwargs) 130 | 131 | predictor = vit_predictor.__dict__["vit_predictor"](**vit_predictor_kwargs) 132 | 133 | if pretrained: 134 | model_file = ARCH_NAME_MAP[model_name][-1] 135 | url = VJEPA_BASE_URL + f"/{model_file}.pt" 136 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 137 | encoder_state_dict = _clean_backbone_key(state_dict["encoder"]) 138 | encoder.load_state_dict(encoder_state_dict, strict=False) # state_dict has pos_embed but we use RoPE 139 | predictor_state_dict = _clean_backbone_key(state_dict["predictor"]) 140 | predictor.load_state_dict(predictor_state_dict, strict=False) # state_dict has pos_embed but we use RoPE 141 | 142 | return encoder, predictor 143 | 144 | 145 | def vjepa2_vit_large(*, pretrained: bool = True, **kwargs): 146 | """ 147 | VJEPA 2 ViT-Large model 148 | """ 149 | return _make_vjepa2_model(model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs) 150 | 151 | 152 | def vjepa2_vit_huge(*, pretrained: bool = True, **kwargs): 153 | """ 154 | VJEPA 2 ViT-Huge model 155 | """ 156 | return _make_vjepa2_model(model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs) 157 | 158 | 159 | def vjepa2_vit_giant(*, pretrained: bool = True, **kwargs): 160 | """ 161 | VJEPA 2 ViT-giant model 162 | """ 163 | return _make_vjepa2_model(model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs) 164 | 165 | 166 | def vjepa2_vit_giant_384(*, pretrained: bool = True, **kwargs): 167 | """ 168 | VJEPA 2 ViT-giant-384 model 169 | """ 170 | return _make_vjepa2_model(model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs) 171 | 172 | 173 | def vjepa2_ac_vit_giant(*, pretrained: bool = True, **kwargs): 174 | """ 175 | VJEPA 2-AC ViT-giant model 176 | """ 177 | return _make_vjepa2_ac_model(model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs) 178 | -------------------------------------------------------------------------------- /src/masks/default.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from logging import getLogger 7 | 8 | import torch 9 | 10 | _GLOBAL_SEED = 0 11 | logger = getLogger() 12 | 13 | 14 | class DefaultCollator(object): 15 | 16 | def __call__(self, batch): 17 | collated_batch = torch.utils.data.default_collate(batch) 18 | return collated_batch, None, None 19 | -------------------------------------------------------------------------------- /src/masks/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
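#
# Shape sketch for apply_masks below (sizes illustrative):
#   x: [B, N, D] patch tokens; masks: list of M index tensors, each [B, K]
#   concat=True  -> single tensor [M * B, K, D]
#   concat=False -> list of M tensors, each [B, K, D]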
5 | 6 | import torch 7 | 8 | 9 | def apply_masks(x, masks, concat=True): 10 | """ 11 | :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] 12 | :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep 13 | """ 14 | all_x = [] 15 | for m in masks: 16 | mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) 17 | all_x += [torch.gather(x, dim=1, index=mask_keep)] 18 | if not concat: 19 | return all_x 20 | 21 | return torch.cat(all_x, dim=0) 22 | -------------------------------------------------------------------------------- /src/models/attentive_pooler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock 13 | from src.utils.tensors import trunc_normal_ 14 | 15 | 16 | class AttentivePooler(nn.Module): 17 | """Attentive Pooler""" 18 | 19 | def __init__( 20 | self, 21 | num_queries=1, 22 | embed_dim=768, 23 | num_heads=12, 24 | mlp_ratio=4.0, 25 | depth=1, 26 | norm_layer=nn.LayerNorm, 27 | init_std=0.02, 28 | qkv_bias=True, 29 | complete_block=True, 30 | use_activation_checkpointing=False, 31 | ): 32 | super().__init__() 33 | self.use_activation_checkpointing = use_activation_checkpointing 34 | self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) 35 | 36 | self.complete_block = complete_block 37 | if complete_block: 38 | self.cross_attention_block = CrossAttentionBlock( 39 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer 40 | ) 41 | else: 42 | self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias) 43 | 44 | self.blocks = None 45 | if depth > 1: 46 | self.blocks = nn.ModuleList( 47 | [ 48 | Block( 49 | dim=embed_dim, 50 | num_heads=num_heads, 51 | mlp_ratio=mlp_ratio, 52 | qkv_bias=qkv_bias, 53 | qk_scale=False, 54 | norm_layer=norm_layer, 55 | ) 56 | for i in range(depth - 1) 57 | ] 58 | ) 59 | 60 | self.init_std = init_std 61 | trunc_normal_(self.query_tokens, std=self.init_std) 62 | self.apply(self._init_weights) 63 | self._rescale_blocks() 64 | 65 | def _rescale_blocks(self): 66 | def rescale(param, layer_id): 67 | param.div_(math.sqrt(2.0 * layer_id)) 68 | 69 | layer_id = 0 70 | if self.blocks is not None: 71 | for layer_id, layer in enumerate(self.blocks): 72 | rescale(layer.attn.proj.weight.data, layer_id + 1) 73 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 74 | 75 | if self.complete_block: 76 | rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1) 77 | 78 | def _init_weights(self, m): 79 | if isinstance(m, nn.Linear): 80 | trunc_normal_(m.weight, std=self.init_std) 81 | if isinstance(m, nn.Linear) and m.bias is not None: 82 | nn.init.constant_(m.bias, 0) 83 | elif isinstance(m, nn.LayerNorm): 84 | nn.init.constant_(m.bias, 0) 85 | nn.init.constant_(m.weight, 1.0) 86 | elif isinstance(m, nn.Conv2d): 87 | trunc_normal_(m.weight, std=self.init_std) 88 | if m.bias is not None: 89 | nn.init.constant_(m.bias, 0) 90 | 91 | def forward(self, x): 92 | if self.blocks is not None: 93 | for blk in self.blocks: 94 | if self.use_activation_checkpointing: 95 | x = torch.utils.checkpoint.checkpoint(blk, x, False, None, 
use_reentrant=False) 96 | else: 97 | x = blk(x) 98 | q = self.query_tokens.repeat(len(x), 1, 1) 99 | q = self.cross_attention_block(q, x) 100 | return q 101 | 102 | 103 | class AttentiveClassifier(nn.Module): 104 | """Attentive Classifier""" 105 | 106 | def __init__( 107 | self, 108 | embed_dim=768, 109 | num_heads=12, 110 | mlp_ratio=4.0, 111 | depth=1, 112 | norm_layer=nn.LayerNorm, 113 | init_std=0.02, 114 | qkv_bias=True, 115 | num_classes=1000, 116 | complete_block=True, 117 | use_activation_checkpointing=False, 118 | ): 119 | super().__init__() 120 | self.pooler = AttentivePooler( 121 | num_queries=1, 122 | embed_dim=embed_dim, 123 | num_heads=num_heads, 124 | mlp_ratio=mlp_ratio, 125 | depth=depth, 126 | norm_layer=norm_layer, 127 | init_std=init_std, 128 | qkv_bias=qkv_bias, 129 | complete_block=complete_block, 130 | use_activation_checkpointing=use_activation_checkpointing, 131 | ) 132 | self.linear = nn.Linear(embed_dim, num_classes, bias=True) 133 | 134 | def forward(self, x): 135 | x = self.pooler(x).squeeze(1) 136 | x = self.linear(x) 137 | return x 138 | -------------------------------------------------------------------------------- /src/models/utils/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | from einops import rearrange 8 | 9 | 10 | class PatchEmbed(nn.Module): 11 | """ 12 | Image to Patch Embedding 13 | """ 14 | 15 | def __init__(self, patch_size=16, in_chans=3, embed_dim=768): 16 | super().__init__() 17 | self.patch_size = patch_size 18 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 19 | 20 | def forward(self, x): 21 | B, C, H, W = x.shape 22 | x = self.proj(x).flatten(2).transpose(1, 2) 23 | return x 24 | 25 | 26 | class PatchEmbed3D(nn.Module): 27 | """ 28 | Image to Patch Embedding 29 | """ 30 | 31 | def __init__( 32 | self, 33 | patch_size=16, 34 | tubelet_size=2, 35 | in_chans=3, 36 | embed_dim=768, 37 | ): 38 | super().__init__() 39 | self.patch_size = patch_size 40 | self.tubelet_size = tubelet_size 41 | 42 | self.proj = nn.Conv3d( 43 | in_channels=in_chans, 44 | out_channels=embed_dim, 45 | kernel_size=(tubelet_size, patch_size, patch_size), 46 | stride=(tubelet_size, patch_size, patch_size), 47 | ) 48 | 49 | def forward(self, x, **kwargs): 50 | B, C, T, H, W = x.shape 51 | x = self.proj(x).flatten(2).transpose(1, 2) 52 | return x 53 | -------------------------------------------------------------------------------- /src/models/utils/pos_embs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
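#
# Shape sketch (numbers illustrative): get_3d_sincos_pos_embed(embed_dim=1024,
# grid_size=16, grid_depth=8) returns a [8 * 16 * 16, 1024] array. Each axis uses the
# standard sin/cos encoding with frequencies omega_i = 1 / 10000 ** (2 * i / d_axis),
# where d_axis is the per-axis slice of embed_dim.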
5 | 6 | import numpy as np 7 | 8 | 9 | def get_3d_sincos_pos_embed(embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False): 10 | """ 11 | grid_size: int of the grid height and width 12 | grid_depth: int of the grid depth 13 | returns: 14 | pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) 15 | or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) 16 | """ 17 | grid_d = np.arange(grid_depth, dtype=float) 18 | grid_h = np.arange(grid_size, dtype=float) 19 | grid_w = np.arange(grid_size, dtype=float) 20 | grid_h, grid_d, grid_w = np.meshgrid( 21 | grid_h, grid_d, grid_w 22 | ) # order of meshgrid is very important for indexing as [d,h,w] 23 | 24 | if not uniform_power: 25 | h_embed_dim = embed_dim // 4 26 | w_embed_dim = embed_dim // 4 27 | d_embed_dim = embed_dim // 2 28 | else: 29 | h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2) 30 | 31 | emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) 32 | emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) 33 | emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) 34 | pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) 35 | pos_embed = pos_embed[:, :embed_dim] 36 | if cls_token: 37 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 38 | return pos_embed 39 | 40 | 41 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 42 | """ 43 | grid_size: int of the grid height and width 44 | returns: 45 | pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) 46 | or [1+grid_size*grid_size, embed_dim] (w/ cls_token) 47 | """ 48 | grid_h = np.arange(grid_size, dtype=float) 49 | grid_w = np.arange(grid_size, dtype=float) 50 | grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] 51 | 52 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) 53 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) 54 | pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 55 | if cls_token: 56 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 57 | return pos_embed 58 | 59 | 60 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 61 | """ 62 | embed_dim: output dimension for each position 63 | grid_size: int of the grid length 64 | returns: 65 | pos_embed: [grid_size, embed_dim] (w/o cls_token) 66 | or [1+grid_size, embed_dim] (w/ cls_token) 67 | """ 68 | grid = np.arange(grid_size, dtype=float) 69 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) 70 | if cls_token: 71 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 72 | return pos_embed 73 | 74 | 75 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 76 | """ 77 | embed_dim: output dimension for each position 78 | pos: a list of positions to be encoded: size (M,) 79 | returns: (M, D) 80 | """ 81 | assert embed_dim % 2 == 0 82 | omega = np.arange(embed_dim // 2, dtype=float) 83 | omega /= embed_dim / 2.0 84 | omega = 1.0 / 10000**omega # (D/2,) 85 | 86 | pos = pos.reshape(-1) # (M,) 87 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 88 | 89 | emb_sin = np.sin(out) # (M, D/2) 90 | emb_cos = np.cos(out) # (M, D/2) 91 | 92 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 93 | return emb 94 | -------------------------------------------------------------------------------- /src/utils/checkpoint_loader.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import random 8 | import time 9 | from typing import Any 10 | 11 | import torch 12 | from torch.serialization import MAP_LOCATION 13 | 14 | from src.utils.logging import get_logger 15 | 16 | logger = get_logger(os.path.basename(__file__)) 17 | 18 | 19 | def robust_checkpoint_loader(r_path: str, map_location: MAP_LOCATION = "cpu", max_retries: int = 3) -> Any: 20 | """ 21 | Loads a checkpoint from a path, retrying up to max_retries times (with backoff) if loading fails. 22 | """ 23 | retries = 0 24 | 25 | while retries < max_retries: 26 | try: 27 | return torch.load(r_path, map_location=map_location) 28 | except Exception as e: 29 | logger.warning(f"Encountered exception when loading checkpoint {e}") 30 | retries += 1 31 | if retries < max_retries: 32 | sleep_time_s = (2**retries) * random.uniform(1.0, 1.1) 33 | logger.warning(f"Sleeping {sleep_time_s}s and trying again, count {retries}/{max_retries}") 34 | time.sleep(sleep_time_s) 35 | continue 36 | else: 37 | raise e 38 | -------------------------------------------------------------------------------- /src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | from src.utils.logging import get_logger 13 | 14 | logger = get_logger() 15 | 16 | 17 | def init_distributed(port=37129, rank_and_world_size=(None, None)): 18 | # Try to set all environment variables up front to avoid triggering a segfault: 19 | # environment variables can be reallocated while torch.distributed.init_process_group is executing, 20 | # so a race condition may trigger if init_process_group modifies an environment variable at 21 | # the same time as Python; we therefore set all environment variables before initializing distributed 22 | if "SLURM_JOB_ID" in os.environ: 23 | # Use the slurm_tmpdir (if it exists) instead of /tmp 24 | tmpdir = Path(f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}") 25 | if tmpdir.exists(): 26 | os.environ["TMPDIR"] = str(tmpdir) 27 | 28 | if dist.is_available() and dist.is_initialized(): 29 | return dist.get_world_size(), dist.get_rank() 30 | 31 | rank, world_size = rank_and_world_size 32 | os.environ["MASTER_ADDR"] = "localhost" 33 | 34 | if (rank is None) or (world_size is None): 35 | try: 36 | world_size = int(os.environ["SLURM_NTASKS"]) 37 | rank = int(os.environ["SLURM_PROCID"]) 38 | os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"] 39 | except Exception: 40 | logger.info("SLURM vars not set (distributed training not available)") 41 | world_size, rank = 1, 0 42 | return world_size, rank 43 | 44 | try: 45 | os.environ["MASTER_PORT"] = str(port) 46 | torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank) 47 | except Exception as e: 48 | world_size, rank = 1, 0 49 | logger.info(f"Rank: {rank}. 
Distributed training not available {e}") 50 | 51 | return world_size, rank 52 | 53 | 54 | class AllGather(torch.autograd.Function): 55 | 56 | @staticmethod 57 | def forward(ctx, x): 58 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 59 | x = x.contiguous() 60 | outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 61 | dist.all_gather(outputs, x) 62 | return torch.cat(outputs, 0) 63 | return x 64 | 65 | @staticmethod 66 | def backward(ctx, grads): 67 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 68 | s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() 69 | e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) 70 | grads = grads.contiguous() 71 | dist.all_reduce(grads) 72 | return grads[s:e] 73 | return grads 74 | 75 | 76 | class AllReduceSum(torch.autograd.Function): 77 | 78 | @staticmethod 79 | def forward(ctx, x): 80 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 81 | x = x.contiguous() 82 | dist.all_reduce(x) 83 | return x 84 | 85 | @staticmethod 86 | def backward(ctx, grads): 87 | return grads 88 | 89 | 90 | class AllReduce(torch.autograd.Function): 91 | 92 | @staticmethod 93 | def forward(ctx, x): 94 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 95 | x = x.contiguous() / dist.get_world_size() 96 | dist.all_reduce(x) 97 | return x 98 | 99 | @staticmethod 100 | def backward(ctx, grads): 101 | return grads 102 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import logging 7 | import os 8 | import subprocess 9 | import sys 10 | 11 | import torch 12 | 13 | 14 | def gpu_timer(closure, log_timings=True): 15 | """Helper to time gpu-time to execute closure()""" 16 | log_timings = log_timings and torch.cuda.is_available() 17 | 18 | elapsed_time = -1.0 19 | if log_timings: 20 | start = torch.cuda.Event(enable_timing=True) 21 | end = torch.cuda.Event(enable_timing=True) 22 | start.record() 23 | 24 | result = closure() 25 | 26 | if log_timings: 27 | end.record() 28 | torch.cuda.synchronize() 29 | elapsed_time = start.elapsed_time(end) 30 | 31 | return result, elapsed_time 32 | 33 | 34 | LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(name)-20s][%(funcName)-25s] %(message)s" 35 | DATE_FORMAT = "%Y-%m-%d %H:%M:%S" 36 | 37 | 38 | def get_logger(name=None, force=False): 39 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) 40 | return logging.getLogger(name=name) 41 | 42 | 43 | class CSVLogger(object): 44 | 45 | def __init__(self, fname, *argv, **kwargs): 46 | self.fname = fname 47 | self.types = [] 48 | mode = kwargs.get("mode", "+a") 49 | self.delim = kwargs.get("delim", ",") 50 | # -- print headers 51 | with open(self.fname, mode) as f: 52 | for i, v in enumerate(argv, 1): 53 | self.types.append(v[0]) 54 | if i < len(argv): 55 | print(v[1], end=self.delim, file=f) 56 | else: 57 | print(v[1], end="\n", file=f) 58 | 59 | def log(self, *argv): 60 | with open(self.fname, "+a") as f: 61 | for i, tv in enumerate(zip(self.types, argv), 1): 62 | end = self.delim if i < len(argv) else "\n" 63 | print(tv[0] % tv[1], end=end, file=f) 64 | 65 | 66 | class AverageMeter(object): 67 | """computes and stores the average and current value""" 68 | 69 | def __init__(self): 70 | self.reset() 71 | 72 | def reset(self): 73 | self.val = 0 74 | self.avg = 0 75 | self.max = float("-inf") 76 | self.min = float("inf") 77 | self.sum = 0 78 | self.count = 0 79 | 80 | def update(self, val, n=1): 81 | self.val = val 82 | try: 83 | self.max = max(val, self.max) 84 | self.min = min(val, self.min) 85 | except Exception: 86 | pass 87 | self.sum += val * n 88 | self.count += n 89 | self.avg = self.sum / self.count 90 | 91 | 92 | def jepa_rootpath(): 93 | this_file = os.path.abspath(__file__) 94 | return "/".join(this_file.split("/")[:-3]) 95 | 96 | 97 | def git_information(): 98 | jepa_root = jepa_rootpath() 99 | try: 100 | resp = ( 101 | subprocess.check_output(["git", "-C", jepa_root, "rev-parse", "HEAD", "--abbrev-ref", "HEAD"]) 102 | .decode("ascii") 103 | .strip() 104 | ) 105 | commit, branch = resp.split("\n") 106 | return f"branch: {branch}\ncommit: {commit}\n" 107 | except Exception: 108 | return "unknown" 109 | -------------------------------------------------------------------------------- /src/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | 8 | 9 | class WSDSchedule(object): 10 | 11 | def __init__(self, optimizer, warmup_steps, anneal_steps, T_max, start_lr, ref_lr, final_lr=0.0): 12 | self.optimizer = optimizer 13 | self.start_lr = start_lr 14 | self.ref_lr = ref_lr 15 | self.final_lr = final_lr 16 | self.anneal_steps = anneal_steps 17 | self.warmup_steps = warmup_steps 18 | self.T_max = T_max - warmup_steps - anneal_steps 19 | self._step = 0.0 20 | 21 | def step(self): 22 | self._step += 1 23 | if self._step < self.warmup_steps: 24 | progress = float(self._step) / float(max(1, self.warmup_steps)) 25 | new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) 26 | elif self._step < self.T_max + self.warmup_steps: 27 | new_lr = self.ref_lr 28 | else: 29 | _step = self._step - (self.T_max + self.warmup_steps) 30 | progress = float(_step) / float(max(1, self.anneal_steps)) 31 | new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr) 32 | 33 | for group in self.optimizer.param_groups: 34 | group["lr"] = new_lr 35 | if "lr_scale" in group: 36 | group["lr"] *= group["lr_scale"] 37 | 38 | return new_lr 39 | 40 | 41 | class WarmupCosineSchedule(object): 42 | 43 | def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0): 44 | self.optimizer = optimizer 45 | self.start_lr = start_lr 46 | self.ref_lr = ref_lr 47 | self.final_lr = final_lr 48 | self.warmup_steps = warmup_steps 49 | self.T_max = T_max - warmup_steps 50 | self._step = 0.0 51 | 52 | def step(self): 53 | self._step += 1 54 | if self._step < self.warmup_steps: 55 | progress = float(self._step) / float(max(1, self.warmup_steps)) 56 | new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) 57 | else: 58 | # -- progress after warmup 59 | progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) 60 | new_lr = max( 61 | self.final_lr, 62 | self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)), 63 | ) 64 | 65 | for group in self.optimizer.param_groups: 66 | group["lr"] = new_lr 67 | 68 | return new_lr 69 | 70 | 71 | class CosineWDSchedule(object): 72 | 73 | def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0): 74 | self.optimizer = optimizer 75 | self.ref_wd = ref_wd 76 | self.final_wd = final_wd 77 | self.T_max = T_max 78 | self._step = 0.0 79 | 80 | def step(self): 81 | self._step += 1 82 | progress = self._step / self.T_max 83 | new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress)) 84 | 85 | if self.final_wd <= self.ref_wd: 86 | new_wd = max(self.final_wd, new_wd) 87 | else: 88 | new_wd = min(self.final_wd, new_wd) 89 | 90 | for group in self.optimizer.param_groups: 91 | if ("WD_exclude" not in group) or not group["WD_exclude"]: 92 | group["weight_decay"] = new_wd 93 | return new_wd 94 | -------------------------------------------------------------------------------- /src/utils/tensors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import math 7 | from logging import getLogger 8 | 9 | import torch 10 | 11 | logger = getLogger() 12 | 13 | 14 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 15 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 16 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 17 | def norm_cdf(x): 18 | # Computes standard normal cumulative distribution function 19 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 20 | 21 | with torch.no_grad(): 22 | # Values are generated by using a truncated uniform distribution and 23 | # then using the inverse CDF for the normal distribution. 24 | # Get upper and lower cdf values 25 | lower = norm_cdf((a - mean) / std) 26 | upper = norm_cdf((b - mean) / std) 27 | 28 | # Uniformly fill tensor with values from [lower, upper], then translate to 29 | # [2*lower-1, 2*upper-1]. 30 | tensor.uniform_(2 * lower - 1, 2 * upper - 1) 31 | 32 | # Use inverse cdf transform for normal distribution to get truncated 33 | # standard normal 34 | tensor.erfinv_() 35 | 36 | # Transform to proper mean, std 37 | tensor.mul_(std * math.sqrt(2.0)) 38 | tensor.add_(mean) 39 | 40 | # Clamp to ensure it's in the proper range 41 | tensor.clamp_(min=a, max=b) 42 | return tensor 43 | 44 | 45 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 46 | # type: (Tensor, float, float, float, float) -> Tensor 47 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 48 | 49 | 50 | def repeat_interleave_batch(x, B, repeat): 51 | N = len(x) // B 52 | x = torch.cat([torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) for i in range(N)], dim=0) 53 | return x 54 | -------------------------------------------------------------------------------- /src/utils/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch.nn as nn 7 | 8 | 9 | class MultiSeqWrapper(nn.Module): 10 | 11 | def __init__(self, backbone): 12 | super().__init__() 13 | self.backbone = backbone 14 | 15 | def forward(self, x, masks=None): 16 | """ 17 | :param x: [list] List of Tensors of different seq lengths 18 | :param masks: [list[list]] List of Tensors (outer index: masks for a given sequence length, inner index: multi-masks for that sequence length) 19 | """ 20 | if masks is None: 21 | return [self.backbone(xi) for xi in x] 22 | 23 | outs = [[] for _ in x] 24 | for i, (xi, mi) in enumerate(zip(x, masks)): 25 | for mij in mi: 26 | outs[i] += [self.backbone(xi, masks=mij)] 27 | return outs 28 | 29 | 30 | class PredictorMultiSeqWrapper(nn.Module): 31 | 32 | def __init__(self, backbone): 33 | super().__init__() 34 | self.backbone = backbone 35 | 36 | def forward(self, x, masks_x, masks_y, has_cls=False): 37 | n = 0 38 | outs = [[] for _ in x] 39 | for i, (xi, mxi, myi) in enumerate(zip(x, masks_x, masks_y)): 40 | for xij, mxij, myij in zip(xi, mxi, myi): 41 | outs[i] += [self.backbone(xij, mxij, myij, mask_index=i, has_cls=has_cls)] 42 | n += 1 43 | return outs 44 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import sys 8 | 9 | JEPA_ROOT = os.path.dirname(os.path.dirname(__file__)) 10 | sys.path.append(JEPA_ROOT) 11 | -------------------------------------------------------------------------------- /tests/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import sys 8 | 9 | JEPA_ROOT = os.path.dirname(os.path.dirname(__file__)) 10 | sys.path.append(JEPA_ROOT) 11 | -------------------------------------------------------------------------------- /tests/datasets/test_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import unittest 7 | 8 | from src.datasets.utils.dataloader import ConcatIndices 9 | 10 | 11 | class TestConcatIndices(unittest.TestCase): 12 | def test_concat_indices(self): 13 | sizes = [10, 20, 30, 40] 14 | total_size = sum(sizes) 15 | concat_indices = ConcatIndices(sizes) 16 | 17 | # -1 is outside the total range 18 | with self.assertRaises(ValueError): 19 | concat_indices[-1] 20 | # 0-9 map to dataset 0 21 | self.assertEqual(concat_indices[0], (0, 0)) 22 | self.assertEqual(concat_indices[9], (0, 9)) 23 | # 10-29 map to dataset 1 24 | self.assertEqual(concat_indices[10], (1, 0)) 25 | self.assertEqual(concat_indices[29], (1, 19)) 26 | # 30-59 map to dataset 2 27 | self.assertEqual(concat_indices[30], (2, 0)) 28 | self.assertEqual(concat_indices[59], (2, 29)) 29 | # 60-99 map to dataset 3 30 | self.assertEqual(concat_indices[60], (3, 0)) 31 | self.assertEqual(concat_indices[99], (3, 39)) 32 | # 100 is outside the total range 33 | with self.assertRaises(ValueError): 34 | concat_indices[total_size] 35 | -------------------------------------------------------------------------------- /tests/datasets/test_memory_efficient_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import unittest 7 | 8 | from src.datasets.utils.weighted_sampler import ( 9 | MemoryEfficientDistributedWeightedSampler, 10 | MemoryEfficientDistributedWeightedSamplerLessRepeat, 11 | ) 12 | 13 | 14 | class MockDataset: 15 | def __init__(self, datasets, dataset_weights): 16 | self.datasets = datasets 17 | self.dataset_weights = dataset_weights 18 | 19 | def __len__(self): 20 | return sum(len(d) for d in self.datasets) 21 | 22 | 23 | class TestMemoryEfficientSampler(unittest.TestCase): 24 | 25 | def test_shuffled_sampling_single(self): 26 | "The specific values returned are a function of the random sampler with the given seed." 
27 | datasets = [] 28 | for i in range(3): 29 | datasets.append([f"DS{i}"] * 100 * (i + 1)) 30 | 31 | mock_dataset = MockDataset(datasets, [1, 1, 1]) 32 | sampler = MemoryEfficientDistributedWeightedSampler(mock_dataset, num_replicas=1, rank=0, shuffle=True) 33 | 34 | smplr_it = iter(sampler) 35 | 36 | ex = next(smplr_it) 37 | self.assertIsNotNone(ex) 38 | self.assertEqual(ex, 202) # Based on previous run 39 | 40 | ex = next(smplr_it) 41 | self.assertIsNotNone(ex) 42 | self.assertEqual(ex, 26) # Based on previous run 43 | 44 | def test_shuffled_sampling(self): 45 | datasets = [] 46 | for i in range(3): 47 | datasets.append([f"DS{i}"] * 100 * (i + 1)) 48 | 49 | mock_dataset = MockDataset(datasets, [1, 2000, 1]) 50 | sampler = MemoryEfficientDistributedWeightedSampler(mock_dataset, num_replicas=8, rank=3, shuffle=True) 51 | 52 | smplr_it = iter(sampler) 53 | 54 | # Notice how the following samples are drawn from the 2nd dataset, which has a weight of 2000. 55 | ex = next(smplr_it) 56 | self.assertIsNotNone(ex) 57 | self.assertEqual(ex, 135) # Based on previous run 58 | 59 | ex = next(smplr_it) 60 | self.assertIsNotNone(ex) 61 | self.assertEqual(ex, 143) # Based on previous run 62 | 63 | def test_non_shuffled_sampling(self): 64 | datasets = [] 65 | for i in range(3): 66 | datasets.append([f"DS{i}"] * 100 * (i + 1)) 67 | 68 | mock_dataset = MockDataset(datasets, [1, 10, 1]) 69 | sampler = MemoryEfficientDistributedWeightedSampler(mock_dataset, num_replicas=4, rank=2, shuffle=False) 70 | 71 | smplr_it = iter(sampler) 72 | 73 | ex = next(smplr_it) 74 | self.assertIsNotNone(ex) 75 | self.assertEqual(ex, 102) # Calculated based on the `__next__` function's non shuffled logic. 76 | 77 | ex = next(smplr_it) 78 | self.assertIsNotNone(ex) 79 | self.assertEqual(ex, 106) # Calculated based on the `__next__` function's non shuffled logic. 80 | 81 | 82 | class TestMemoryEfficientSamplerLessRepeat(unittest.TestCase): 83 | 84 | def test_shuffled_sampling_single(self): 85 | """ 86 | Testing all weights are equal to 1. 87 | The specific values returned are a function of the random sampler with the given seed. 88 | """ 89 | datasets = [] 90 | for i in range(3): 91 | datasets.append([f"DS{i}"] * 100 * (i + 1)) 92 | 93 | mock_dataset = MockDataset(datasets, [1, 1, 1]) 94 | sampler = MemoryEfficientDistributedWeightedSamplerLessRepeat( 95 | mock_dataset, num_replicas=1, rank=0, shuffle=True 96 | ) 97 | 98 | smplr_it = iter(sampler) 99 | 100 | ex = next(smplr_it) 101 | self.assertIsNotNone(ex) 102 | self.assertEqual(ex, 144) # Based on previous run 103 | 104 | ex = next(smplr_it) 105 | self.assertIsNotNone(ex) 106 | self.assertEqual(ex, 84) # Based on previous run 107 | 108 | def test_shuffled_sampling(self): 109 | """ 110 | Testing one dominant dataset. 111 | The specific values returned are a function of the random sampler with the given seed. 112 | """ 113 | datasets = [] 114 | for i in range(3): 115 | datasets.append([f"DS{i}"] * 100 * (i + 1)) 116 | 117 | mock_dataset = MockDataset(datasets, [1, 2000, 1]) 118 | sampler = MemoryEfficientDistributedWeightedSamplerLessRepeat( 119 | mock_dataset, num_replicas=8, rank=3, shuffle=True 120 | ) 121 | 122 | smplr_it = iter(sampler) 123 | 124 | # Notice how the following samples are drawn from the 2nd dataset, which has a weight of 2000. 
125 | ex = next(smplr_it) 126 | self.assertIsNotNone(ex) 127 | self.assertEqual(ex, 255) # Based on previous run 128 | 129 | ex = next(smplr_it) 130 | self.assertIsNotNone(ex) 131 | self.assertEqual(ex, 231) # Based on previous run 132 | 133 | def test_non_shuffled_sampling(self): 134 | datasets = [] 135 | for i in range(3): 136 | datasets.append([f"DS{i}"] * 100 * (i + 1)) 137 | 138 | mock_dataset = MockDataset(datasets, [1, 10, 1]) 139 | sampler = MemoryEfficientDistributedWeightedSamplerLessRepeat( 140 | mock_dataset, num_replicas=4, rank=2, shuffle=False 141 | ) 142 | 143 | smplr_it = iter(sampler) 144 | 145 | ex = next(smplr_it) 146 | self.assertIsNotNone(ex) 147 | self.assertEqual(ex, 102) # Calculated based on the `__next__` function's non shuffled logic. 148 | 149 | ex = next(smplr_it) 150 | self.assertIsNotNone(ex) 151 | self.assertEqual(ex, 106) # Calculated based on the `__next__` function's non shuffled logic. 152 | -------------------------------------------------------------------------------- /tests/datasets/test_vjepa_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import unittest 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from app.vjepa import transforms 12 | from src.datasets.utils.video import functional 13 | from src.datasets.utils.video.volume_transforms import ClipToTensor 14 | 15 | 16 | class TestNormalize(unittest.TestCase): 17 | 18 | def setUp(self): 19 | self.g = torch.Generator() 20 | self.g.manual_seed(42) 21 | 22 | def test_approximation_equivalence(self): 23 | T, H, W, C = 16, 224, 224, 3 24 | shape = (T, H, W, C) 25 | mean = torch.tensor([0.485, 0.456, 0.406]) 26 | std = torch.tensor([0.229, 0.224, 0.225]) 27 | for i in range(10): 28 | X = torch.randint(low=0, high=255, size=shape, generator=self.g, dtype=torch.uint8) 29 | X_clone = X.clone().permute(3, 0, 1, 2) # C, T, H, W 30 | 31 | X_norm = transforms.tensor_normalize(X, mean, std) 32 | X_norm_fast = transforms._tensor_normalize_inplace(X_clone, 255.0 * mean, 255.0 * std) 33 | self.assertTrue(torch.allclose(X_norm, X_norm_fast.permute(1, 2, 3, 0))) 34 | 35 | 36 | class TestVideoTransformFunctionalCrop(unittest.TestCase): 37 | def test_tensor_numpy(self): 38 | T, C, H, W = 16, 3, 280, 320 39 | shape = (T, C, H, W) 40 | crop_size = (10, 10, 224, 224) 41 | video_tensor = torch.randint(low=0, high=255, size=shape, dtype=torch.uint8) 42 | video_numpy = video_tensor.numpy() 43 | 44 | cropped_tensor = functional.crop_clip(video_tensor, *crop_size) 45 | self.assertIsInstance(cropped_tensor[0], torch.Tensor) 46 | 47 | cropped_np_array = functional.crop_clip(video_numpy, *crop_size) 48 | self.assertIsInstance(cropped_np_array[0], np.ndarray) 49 | 50 | for clip_tensor, clip_np in zip(cropped_tensor, cropped_np_array): 51 | torch.testing.assert_close(clip_tensor, torch.Tensor(clip_np).to(dtype=torch.uint8)) 52 | 53 | 54 | class TestVideoTransformFunctionalResize(unittest.TestCase): 55 | def test_tensor_numpy(self): 56 | T, C, H, W = 16, 3, 280, 320 57 | shape = (T, C, H, W) 58 | resize_to = 256 59 | 60 | video_tensor = torch.randint(low=0, high=255, size=shape, dtype=torch.int16) 61 | # We permute the videos because our underlying numpy.array based transforms are expecting 62 | # images in (H, W, C) shape whereas our tensor transforms are mostly in (C, H, W) 63 
| video_numpy = video_tensor.permute(0, 2, 3, 1).numpy() # (T, C, H, W) -> (T, H, W, C) 64 | 65 | resized_tensor = functional.resize_clip(video_tensor, resize_to) 66 | self.assertIsInstance(resized_tensor[0], torch.Tensor) 67 | 68 | resized_np_array = functional.resize_clip(video_numpy, resize_to) 69 | self.assertIsInstance(resized_np_array[0], np.ndarray) 70 | 71 | for clip_tensor, clip_np in zip(resized_tensor, resized_np_array): 72 | clip_tensor = clip_tensor.permute(1, 2, 0) 73 | diff = torch.mean((torch.abs(clip_tensor - torch.Tensor(clip_np).to(torch.int16))) / (clip_tensor + 1)) 74 | 75 | # Transformations cannot match exactly because their interpolation functions come from 76 | # two different sources. Here we check their relative differences instead. 77 | # See the discussion here: https://github.com/fairinternal/jepa-internal/pull/65#issuecomment-2101833959 78 | self.assertLess(diff, 0.05) 79 | 80 | 81 | class TestVideoTransformClipToTensor(unittest.TestCase): 82 | def test_tensor_numpy(self): 83 | T, C, H, W = 16, 3, 280, 320 84 | shape = (T, C, H, W) 85 | transform = ClipToTensor() 86 | 87 | video_tensor = [clip for clip in torch.randint(low=0, high=255, size=shape, dtype=torch.int16)] 88 | # We permute the videos because our underlying numpy.array based transforms are expecting 89 | # images in (H, W, C) shape whereas our tensor transforms are mostly in (C, H, W) 90 | video_numpy = [clip.permute(1, 2, 0).numpy() for clip in video_tensor] 91 | torch.testing.assert_close(transform(video_tensor), transform(video_numpy)) 92 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import sys 8 | 9 | JEPA_ROOT = os.path.dirname(os.path.dirname(__file__)) 10 | sys.path.append(JEPA_ROOT) 11 | -------------------------------------------------------------------------------- /tests/models/test_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import unittest 7 | 8 | import torch 9 | 10 | from src.models.vision_transformer import VIT_EMBED_DIMS, vit_tiny 11 | 12 | 13 | class TestImageViT(unittest.TestCase): 14 | def setUp(self) -> None: 15 | self._vit_tiny = vit_tiny() 16 | self.height, self.width = 224, 224 17 | self.num_patches = (self.height // self._vit_tiny.patch_size) * (self.width // self._vit_tiny.patch_size) 18 | 19 | def test_model_image_nomask_batchsize_4(self): 20 | BS = 4 21 | x = torch.rand((BS, 3, self.height, self.width)) 22 | y = self._vit_tiny(x) 23 | self.assertIsInstance(y, torch.Tensor) 24 | self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"])) 25 | 26 | def test_model_image_nomask_batchsize_1(self): 27 | BS = 1 28 | x = torch.rand((BS, 3, self.height, self.width)) 29 | y = self._vit_tiny(x) 30 | self.assertIsInstance(y, torch.Tensor) 31 | self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"])) 32 | 33 | def test_model_image_masked_batchsize_4(self): 34 | BS = 4 35 | mask_indices = [6, 7, 8] 36 | masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)] 37 | x = torch.rand((BS, 3, self.height, self.width)) 38 | y = self._vit_tiny(x, masks=masks) 39 | self.assertIsInstance(y, torch.Tensor) 40 | self.assertEqual(y.size(), (BS, len(mask_indices), VIT_EMBED_DIMS["vit_tiny"])) 41 | 42 | def test_model_image_masked_batchsize_1(self): 43 | BS = 1 44 | mask_indices = [6, 7, 8] 45 | masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)] 46 | x = torch.rand((BS, 3, self.height, self.width)) 47 | y = self._vit_tiny(x, masks=masks) 48 | self.assertIsInstance(y, torch.Tensor) 49 | self.assertEqual(y.size(), (BS, len(mask_indices), VIT_EMBED_DIMS["vit_tiny"])) 50 | 51 | 52 | class TestVideoViT(unittest.TestCase): 53 | def setUp(self) -> None: 54 | self.num_frames = 8 55 | self._vit_tiny = vit_tiny(num_frames=8) 56 | self.height, self.width = 224, 224 57 | self.num_patches = ( 58 | (self.height // self._vit_tiny.patch_size) 59 | * (self.width // self._vit_tiny.patch_size) 60 | * (self.num_frames // self._vit_tiny.tubelet_size) 61 | ) 62 | 63 | def test_model_video_nomask_batchsize_4(self): 64 | BS = 4 65 | x = torch.rand((BS, 3, self.num_frames, self.height, self.width)) 66 | y = self._vit_tiny(x) 67 | self.assertIsInstance(y, torch.Tensor) 68 | self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"])) 69 | 70 | def test_model_video_nomask_batchsize_1(self): 71 | BS = 1 72 | x = torch.rand((BS, 3, self.num_frames, self.height, self.width)) 73 | y = self._vit_tiny(x) 74 | self.assertIsInstance(y, torch.Tensor) 75 | self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"])) 76 | 77 | def test_model_video_masked_batchsize_4(self): 78 | BS = 4 79 | mask_indices = [6, 7, 8] 80 | masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)] 81 | x = torch.rand((BS, 3, self.num_frames, self.height, self.width)) 82 | y = self._vit_tiny(x, masks=masks) 83 | self.assertIsInstance(y, torch.Tensor) 84 | self.assertEqual(y.size(), (BS, len(mask_indices), VIT_EMBED_DIMS["vit_tiny"])) 85 | 86 | def test_model_video_masked_batchsize_1(self): 87 | BS = 1 88 | mask_indices = [6, 7, 8] 89 | masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)] 90 | x = torch.rand((BS, 3, self.num_frames, self.height, self.width)) 91 | y = self._vit_tiny(x, masks=masks) 92 | self.assertIsInstance(y, torch.Tensor) 93 | self.assertEqual(y.size(), (BS, len(mask_indices), 
VIT_EMBED_DIMS["vit_tiny"])) 94 | -------------------------------------------------------------------------------- /tests/models/test_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import unittest 7 | 8 | import torch 9 | 10 | from src.models.predictor import VisionTransformerPredictor 11 | 12 | 13 | class TestImagePredictorMaskTokens(unittest.TestCase): 14 | def setUp(self) -> None: 15 | self._embed_dim = 768 16 | self._predictor = VisionTransformerPredictor(embed_dim=self._embed_dim, use_mask_tokens=True) 17 | 18 | def test_image_predictor_batchsize_4(self): 19 | BS = 4 20 | enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)] 21 | target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)] 22 | enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim)) 23 | y = self._predictor(enc, enc_mask_indices, target_mask_indices) 24 | self.assertIsInstance(y, torch.Tensor) 25 | self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim)) 26 | 27 | def test_image_predictor_batchsize_1(self): 28 | BS = 1 29 | enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)] 30 | target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)] 31 | enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim)) 32 | y = self._predictor(enc, enc_mask_indices, target_mask_indices) 33 | self.assertIsInstance(y, torch.Tensor) 34 | self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim)) 35 | 36 | 37 | class TestVideoPredictorMaskTokens(unittest.TestCase): 38 | def setUp(self) -> None: 39 | self._embed_dim = 768 40 | self._predictor = VisionTransformerPredictor(embed_dim=self._embed_dim, use_mask_tokens=True) 41 | 42 | def test_video_predictor_batchsize_4(self): 43 | BS = 4 44 | enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)] 45 | target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)] 46 | enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim)) 47 | y = self._predictor(enc, enc_mask_indices, target_mask_indices) 48 | self.assertIsInstance(y, torch.Tensor) 49 | self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim)) 50 | 51 | def test_video_predictor_batchsize_1(self): 52 | BS = 1 53 | enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)] 54 | target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)] 55 | enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim)) 56 | y = self._predictor(enc, enc_mask_indices, target_mask_indices) 57 | self.assertIsInstance(y, torch.Tensor) 58 | self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim)) 59 | -------------------------------------------------------------------------------- /tests/models/test_vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import unittest 7 | from copy import deepcopy 8 | 9 | import numpy as np 10 | import pytest 11 | import torch 12 | 13 | from src.models.vision_transformer import vit_giant_xformers_rope 14 | 15 | 16 | # Usage: pytest tests/models/test_vision_transformer.py 17 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires CUDA") 18 | class TestViTGiant(unittest.TestCase): 19 | def setUp(self) -> None: 20 | self.model_shape_invariant = vit_giant_xformers_rope( 21 | img_size=256, patch_size=16, num_frames=16, handle_nonsquare_inputs=True 22 | ).cuda() 23 | self.model_square = deepcopy(self.model_shape_invariant) 24 | self.model_square.handle_nonsquare_inputs = False 25 | torch.manual_seed(42) 26 | self.total_iters = 10 27 | 28 | def test_square_inputs(self): 29 | for i in range(self.total_iters): 30 | input = torch.rand(1, 3, 16, 256, 256).cuda() 31 | with torch.cuda.amp.autocast(enabled=True): 32 | with torch.no_grad(): 33 | out1 = self.model_shape_invariant(input) 34 | out2 = self.model_square(input) 35 | torch.testing.assert_close(out1, out2) 36 | 37 | def test_square_inputs_with_mask(self): 38 | for i in range(self.total_iters): 39 | input = torch.rand(1, 3, 16, 256, 256).cuda() 40 | mask = torch.randint(0, 2, (1, 2048)).cuda() 41 | with torch.cuda.amp.autocast(enabled=True): 42 | with torch.no_grad(): 43 | out1 = self.model_shape_invariant(input, masks=mask) 44 | out2 = self.model_square(input, masks=mask) 45 | torch.testing.assert_close(out1, out2) 46 | 47 | def test_nonsquare_inputs(self): 48 | for i in range(self.total_iters): 49 | rand_width = np.random.randint(256, 512) 50 | rand_height = np.random.randint(256, 512) 51 | input = torch.rand(1, 3, 16, rand_height, rand_width).cuda() 52 | # Since input is interpolated, output won't be exactly the same 53 | input_resized_to_square = [ 54 | torch.nn.functional.interpolate(input[:, :, frame_idx], size=256, mode="bicubic") 55 | for frame_idx in range(input.shape[2]) 56 | ] 57 | input_resized_to_square = torch.stack(input_resized_to_square, dim=2) 58 | 59 | with torch.cuda.amp.autocast(enabled=True): 60 | with torch.no_grad(): 61 | out1 = self.model_shape_invariant(input).mean(dim=1) 62 | out2 = self.model_square(input_resized_to_square).mean(dim=1) 63 | self.assertAlmostEqual(torch.nn.functional.cosine_similarity(out2, out1).item(), 1.0, places=3) 64 | --------------------------------------------------------------------------------
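
As a quick orientation to the learning-rate and weight-decay schedules defined in src/utils/schedulers.py above, the following minimal sketch drives WarmupCosineSchedule and CosineWDSchedule with a throwaway model and optimizer. It is illustrative only and not part of the repository; it assumes the script is run from the repository root (or with the package installed in editable mode) so that the src package is importable, and all hyperparameter values are placeholders.

import torch

from src.utils.schedulers import CosineWDSchedule, WarmupCosineSchedule

# Dummy model and optimizer purely for illustration.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.04)

total_steps = 1000
lr_scheduler = WarmupCosineSchedule(
    optimizer,
    warmup_steps=100,  # linear ramp from start_lr up to ref_lr
    start_lr=1e-6,
    ref_lr=1e-3,
    final_lr=1e-6,     # cosine-decay floor after warmup
    T_max=total_steps,
)
wd_scheduler = CosineWDSchedule(optimizer, ref_wd=0.04, final_wd=0.4, T_max=total_steps)

for step in range(total_steps):
    lr = lr_scheduler.step()  # writes the new lr into every optimizer param group and returns it
    wd = wd_scheduler.step()  # writes the new weight decay into groups that do not set WD_exclude
    # ... forward / backward / optimizer.step() / optimizer.zero_grad() would go here

Both schedule objects are stepped once per optimizer update rather than per epoch, which is how the step-based arguments (warmup_steps, T_max) read in the class definitions above.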