Note: this export is truncated to 50k tokens, so only the smallest files in the repository are included in full.
├── .flake8
├── .github
│   └── workflows
│       ├── base_tests.yaml
│       └── linters.yaml
├── .gitignore
├── APACHE-LICENSE
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── app
│   ├── main.py
│   ├── main_distributed.py
│   ├── scaffold.py
│   ├── vjepa
│   │   ├── train.py
│   │   ├── transforms.py
│   │   └── utils.py
│   └── vjepa_droid
│       ├── droid.py
│       ├── train.py
│       ├── transforms.py
│       └── utils.py
├── assets
│   ├── flowchart.png
│   ├── vjepa2-abstract-new.png
│   └── vjepa2-ac-abstract-new.png
├── configs
│   ├── eval
│   │   ├── vitg-384
│   │   │   ├── coin.yaml
│   │   │   ├── diving48.yaml
│   │   │   ├── ek100.yaml
│   │   │   ├── in1k.yaml
│   │   │   ├── jester.yaml
│   │   │   ├── k400.yaml
│   │   │   └── ssv2.yaml
│   │   └── vitl
│   │       ├── coin.yaml
│   │       ├── diving48.yaml
│   │       ├── ek100.yaml
│   │       ├── in1k.yaml
│   │       ├── jester.yaml
│   │       ├── k400.yaml
│   │       └── ssv2.yaml
│   ├── inference
│   │   ├── vitg-384
│   │   │   ├── diving48.yaml
│   │   │   ├── ek100.yaml
│   │   │   └── ssv2.yaml
│   │   └── vitl
│   │       ├── diving48.yaml
│   │       ├── ek100.yaml
│   │       └── ssv2.yaml
│   └── train
│       ├── vitg16
│       │   ├── cooldown-256px-64f.yaml
│       │   ├── cooldown-384px-64f.yaml
│       │   ├── droid-256px-8f.yaml
│       │   └── pretrain-256px-16f.yaml
│       ├── vith16
│       │   ├── cooldown-256px-64f.yaml
│       │   └── pretrain-256px-16f.yaml
│       └── vitl16
│           ├── cooldown-256px-64f.yaml
│           └── pretrain-256px-16f.yaml
├── evals
│   ├── action_anticipation_frozen
│   │   ├── dataloader.py
│   │   ├── epickitchens.py
│   │   ├── eval.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── modelcustom
│   │   │   └── vit_encoder_predictor_concat_ar.py
│   │   ├── models.py
│   │   └── utils.py
│   ├── hub
│   │   ├── __init__.py
│   │   └── preprocessor.py
│   ├── image_classification_frozen
│   │   ├── eval.py
│   │   ├── modelcustom
│   │   │   └── vit_encoder.py
│   │   └── models.py
│   ├── main.py
│   ├── main_distributed.py
│   ├── scaffold.py
│   └── video_classification_frozen
│       ├── eval.py
│       ├── modelcustom
│       │   ├── vit_encoder_multiclip.py
│       │   └── vit_encoder_multiclip_multilevel.py
│       ├── models.py
│       └── utils.py
├── hubconf.py
├── notebooks
│   ├── energy_landscape_example.ipynb
│   ├── franka_example_traj.npz
│   ├── utils
│   │   ├── mpc_utils.py
│   │   └── world_model_wrapper.py
│   ├── vjepa2_demo.ipynb
│   └── vjepa2_demo.py
├── pyproject.toml
├── requirements-test.txt
├── requirements.txt
├── setup.py
├── src
│   ├── datasets
│   │   ├── data_manager.py
│   │   ├── imagenet1k.py
│   │   ├── utils
│   │   │   ├── dataloader.py
│   │   │   ├── utils.py
│   │   │   ├── video
│   │   │   │   ├── functional.py
│   │   │   │   ├── randaugment.py
│   │   │   │   ├── randerase.py
│   │   │   │   ├── transforms.py
│   │   │   │   ├── transforms_builder.py
│   │   │   │   └── volume_transforms.py
│   │   │   ├── weighted_sampler.py
│   │   │   └── worker_init_fn.py
│   │   └── video_dataset.py
│   ├── hub
│   │   ├── __init__.py
│   │   └── backbones.py
│   ├── masks
│   │   ├── default.py
│   │   ├── multiseq_multiblock3d.py
│   │   └── utils.py
│   ├── models
│   │   ├── ac_predictor.py
│   │   ├── attentive_pooler.py
│   │   ├── predictor.py
│   │   ├── utils
│   │   │   ├── modules.py
│   │   │   ├── patch_embed.py
│   │   │   └── pos_embs.py
│   │   └── vision_transformer.py
│   └── utils
│       ├── checkpoint_loader.py
│       ├── distributed.py
│       ├── logging.py
│       ├── monitoring.py
│       ├── schedulers.py
│       ├── tensors.py
│       └── wrappers.py
└── tests
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── test_dataloader.py
    │   ├── test_memory_efficient_sampler.py
    │   └── test_vjepa_transforms.py
    └── models
        ├── __init__.py
        ├── test_models.py
        ├── test_predictor.py
        └── test_vision_transformer.py


/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | select = E,F,W
4 | ignore = E203,E701,W503
5 | per-file-ignores=__init__.py:F401 version.py:F401
6 | 


--------------------------------------------------------------------------------
/.github/workflows/base_tests.yaml:
--------------------------------------------------------------------------------
 1 | name: UnitTests
 2 | 
 3 | on: [push]
 4 | 
 5 | jobs:
 6 |   unittests:
 7 |     runs-on: ubuntu-latest
 8 |     strategy:
 9 |       max-parallel: 4
10 | 
11 |     steps:
12 |     - uses: actions/checkout@v4
13 |     - name: Set up Python 3.12
14 |       uses: actions/setup-python@v5
15 |       with:
16 |         python-version: '3.12'
17 |     - name: Add conda to system path
18 |       run: |
19 |         # $CONDA is an environment variable pointing to the root of the miniconda directory
20 |         echo $CONDA/bin >> $GITHUB_PATH
21 |     - name: Install dependencies
22 |       run: |
23 |         conda create --name test-env python=3.12
24 |         conda install pytest
25 |         echo "Starting setup from $PWD"
26 |         pip install -e .
27 |     - name: Test with pytest
28 |       run: |
29 |         pytest tests
30 | 


--------------------------------------------------------------------------------
/.github/workflows/linters.yaml:
--------------------------------------------------------------------------------
 1 | name: Lint (Common Code)
 2 | 
 3 | on:
 4 |   push:
 5 |     branches:
 6 |       - master
 7 |     paths:
 8 |       - 'app/'
 9 |       - 'evals/*.py'
10 |       - 'src/'
11 |       - 'tests/'
12 |   pull_request:
13 |     branches:
14 |       - master
15 |       - 'gh/**'
16 |     paths:
17 |       - 'app/'
18 |       - 'evals/*.py'
19 |       - 'src/'
20 |       - 'tests/'
21 | 
22 | jobs:
23 |   run-linters:
24 |     name: Run linters
25 |     runs-on: ubuntu-latest
26 | 
27 |     steps:
28 |       - uses: actions/checkout@v4
29 |       - name: Set up Python 3.12
30 |         uses: actions/setup-python@v5
31 |         with:
32 |           python-version: '3.12'
33 |       - name: Install Python lint dependencies
34 |         run: |
35 |           pip install -r requirements-test.txt
36 |       - name: Set lint paths
37 |         run: echo "lint_paths=app evals/*.py src tests" >> "$GITHUB_ENV"
38 |       - name: Run isort
39 |         run: |
40 |           python -m isort $lint_paths --check
41 |       - name: Run flake8
42 |         if: always()
43 |         run: |
44 |           python -m flake8 --config .flake8 --show-source --statistics $lint_paths
45 |       - name: Run black
46 |         if: always()
47 |         run: |
48 |           python -m black --check $lint_paths
49 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | *.pyc
 2 | .vscode/
 3 | .*.swp
 4 | 
 5 | run_vjepa_aws.py
 6 | run.py
 7 | main_distributed_video.py
 8 | main_video.py
 9 | 
10 | app/vjepa/configs/temp_aws
11 | app/main_dev.py
12 | app/main_distributed_dev.py
13 | evals/ava/alphaction/data
14 | 
15 | run_evals.py
16 | run_evals_v2.py
17 | run_pretrain.py
18 | 
19 | *.egg-info/
20 | *.ipynb_checkpoints/
21 | 
22 | traces/
23 | third_party/*
24 | 
25 | evals/simu_env_planning/local/
26 | evals/simu_env_planning/docker2/
27 | evals/simu_env_planning/docker/
28 | app/vjepa_droid/local/
29 | app/vjepa_droid_v2/local/
30 | app/vjepa_droid_v3/local/
31 | app/vjepa_droid_v4/local/
32 | configs/local


--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | 
3 | ## [0.0.1] - 2025-06-05
4 | 
5 | Initial release of V-JEPA 2 codebase


--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
 1 | # Code of Conduct
 2 | 
 3 | ## Our Pledge
 4 | 
 5 | In the interest of fostering an open and welcoming environment, we as
 6 | contributors and maintainers pledge to make participation in our project and
 7 | our community a harassment-free experience for everyone, regardless of age, body
 8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
 9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 | 
12 | ## Our Standards
13 | 
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 | 
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 | 
23 | Examples of unacceptable behavior by participants include:
24 | 
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 | 
34 | ## Our Responsibilities
35 | 
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 | 
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 | 
46 | ## Scope
47 | 
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 | 
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 | 
59 | ## Enforcement
60 | 
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at <opensource-conduct@fb.com>. All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 | 
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 | 
72 | ## Attribution
73 | 
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 | 
77 | [homepage]: https://www.contributor-covenant.org
78 | 
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 | 


--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # Contributing to V-JEPA 2
 2 | We want to make contributing to this project as easy and transparent as
 3 | possible.
 4 | 
 5 | ## Pull Requests
 6 | We welcome your pull requests.
 7 | 
 8 | 1. Fork the repo and create your branch from `main`.
 9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code is consistent with style guidance (below) and lints.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 | 7. Add reviewer(s) for approval.
15 | 
16 | ## Contributor License Agreement ("CLA")
17 | In order to accept your pull request, we need you to submit a CLA. You only need
18 | to do this once to work on any of Facebook's open source projects.
19 | 
20 | Complete your CLA here: <https://code.facebook.com/cla>
21 | 
22 | ## Issues
23 | We use GitHub issues to track public bugs. Please ensure your description is
24 | clear and has sufficient instructions to be able to reproduce the issue.
25 | 
26 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
27 | disclosure of security bugs. In those cases, please go through the process
28 | outlined on that page and do not file a public issue.
29 | 
30 | ## Coding Style
31 | * 4 spaces for indentation rather than tabs
32 | * 119 character line length
33 | * PEP8 formatting
34 | 
35 | We recommend using `black`, `isort`, and `flake8` to format your code changes.
36 | 
37 | ## License
38 | By contributing to this repository, you agree that your contributions will be licensed
39 | under the LICENSE file in the root directory of this source tree.
40 | 
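[Editor's note] The style checks above can be reproduced locally. A minimal sketch, assuming the tool flags and paths used in `.github/workflows/linters.yaml` (the wrapper script itself is illustrative and not part of the repo):

```python
# Run the same three checks as the CI lint job: isort, flake8, black.
import glob
import subprocess
import sys

# Same path list as the workflow's lint_paths ("app evals/*.py src tests").
lint_paths = ["app", "src", "tests"] + glob.glob("evals/*.py")

checks = [
    [sys.executable, "-m", "isort", *lint_paths, "--check"],
    [sys.executable, "-m", "flake8", "--config", ".flake8", "--show-source", "--statistics", *lint_paths],
    [sys.executable, "-m", "black", "--check", *lint_paths],
]
# Run every check (mirroring the workflow's `if: always()`) and fail if any failed.
results = [subprocess.call(cmd) for cmd in checks]
sys.exit(1 if any(results) else 0)
```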


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) Meta Platforms, Inc. and affiliates.
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.


--------------------------------------------------------------------------------
/app/main.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import argparse
 7 | import multiprocessing as mp
 8 | import pprint
 9 | from pathlib import Path
10 | 
11 | import yaml
12 | 
13 | from app.scaffold import main as app_main
14 | from src.utils.distributed import init_distributed
15 | 
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--fname", type=str, help="name of config file to load", default="configs.yaml")
18 | parser.add_argument(
19 |     "--devices",
20 |     type=str,
21 |     nargs="+",
22 |     default=["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"],
23 |     help="which devices to use on local machine",
24 | )
25 | parser.add_argument(
26 |     "--debugmode",
27 |     type=bool,
28 |     default=False,
29 |     help="Setting this to true will not spin up new processes. "
30 |     "The main code runs the main process, which makes it easier to \
31 |     debug with checkpointing.",
32 | )
33 | 
34 | 
35 | def process_main(rank, fname, world_size, devices):
36 |     import os
37 | 
38 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1])
39 | 
40 |     import logging
41 | 
42 |     from src.utils.logging import get_logger
43 | 
44 |     logger = get_logger(force=True)
45 |     if rank == 0:
46 |         logger.setLevel(logging.INFO)
47 |     else:
48 |         logger.setLevel(logging.ERROR)
49 | 
50 |     logger.info(f"called-params {fname}")
51 | 
52 |     # Load config
53 |     params = None
54 |     with open(fname, "r") as y_file:
55 |         params = yaml.load(y_file, Loader=yaml.FullLoader)
56 |         logger.info("loaded params...")
57 | 
58 |     # Log config
59 |     if rank == 0:
60 |         pprint.PrettyPrinter(indent=4).pprint(params)
61 |         folder = params["folder"]
62 |         params_path = os.path.join(folder, "params-pretrain.yaml")
63 |         folder = Path(folder)
64 |         folder.mkdir(parents=True, exist_ok=True)
65 |         with open(params_path, "w") as f:
66 |             yaml.dump(params, f)
67 | 
68 |     # Init distributed (access to comm between GPUS on same machine)
69 |     world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
70 |     logger.info(f"Running... (rank: {rank}/{world_size})")
71 | 
72 |     # Launch the app with loaded config
73 |     app_main(params["app"], args=params)
74 | 
75 | 
76 | if __name__ == "__main__":
77 |     args = parser.parse_args()
78 |     if args.debugmode:
79 |         process_main(rank=0, fname=args.fname, world_size=1, devices=["cuda:0"])
80 |     else:
81 |         num_gpus = len(args.devices)
82 |         mp.set_start_method("spawn")
83 |         for rank in range(num_gpus):
84 |             mp.Process(target=process_main, args=(rank, args.fname, num_gpus, args.devices)).start()
85 | 
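[Editor's note] `app/main.py` only reads two keys from the YAML passed via `--fname`: `app`, which `app/scaffold.py` uses to import `app.<app>.train`, and `folder`, where a copy of the config is written as `params-pretrain.yaml`; the full dict is then forwarded as `args`. A hypothetical minimal config, with placeholder values:

```python
# Illustrative only: the shape of the config dict loaded by app/main.py.
params = {
    "app": "vjepa",                     # dispatched to app/vjepa/train.py by app/scaffold.py
    "folder": "/your_folder/pretrain",  # params-pretrain.yaml is written here by rank 0
    # ...all remaining keys are passed through to app.<app>.train.main(args=params)
}
```

With `--debugmode` set to a truthy value, main.py runs rank 0 in the current process on `cuda:0`; otherwise it spawns one process per entry in `--devices`.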


--------------------------------------------------------------------------------
/app/scaffold.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import importlib
 7 | import logging
 8 | import sys
 9 | 
10 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
11 | logger = logging.getLogger()
12 | 
13 | 
14 | def main(app, args, resume_preempt=False):
15 | 
16 |     logger.info(f"Running pre-training of app: {app}")
17 |     return importlib.import_module(f"app.{app}.train").main(args=args, resume_preempt=resume_preempt)
18 | 


--------------------------------------------------------------------------------
/app/vjepa/transforms.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import torch
  7 | import torchvision.transforms as transforms
  8 | 
  9 | import src.datasets.utils.video.transforms as video_transforms
 10 | from src.datasets.utils.video.randerase import RandomErasing
 11 | 
 12 | 
 13 | def make_transforms(
 14 |     random_horizontal_flip=True,
 15 |     random_resize_aspect_ratio=(3 / 4, 4 / 3),
 16 |     random_resize_scale=(0.3, 1.0),
 17 |     reprob=0.0,
 18 |     auto_augment=False,
 19 |     motion_shift=False,
 20 |     crop_size=224,
 21 |     normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
 22 | ):
 23 | 
 24 |     _frames_augmentation = VideoTransform(
 25 |         random_horizontal_flip=random_horizontal_flip,
 26 |         random_resize_aspect_ratio=random_resize_aspect_ratio,
 27 |         random_resize_scale=random_resize_scale,
 28 |         reprob=reprob,
 29 |         auto_augment=auto_augment,
 30 |         motion_shift=motion_shift,
 31 |         crop_size=crop_size,
 32 |         normalize=normalize,
 33 |     )
 34 |     return _frames_augmentation
 35 | 
 36 | 
 37 | class VideoTransform(object):
 38 | 
 39 |     def __init__(
 40 |         self,
 41 |         random_horizontal_flip=True,
 42 |         random_resize_aspect_ratio=(3 / 4, 4 / 3),
 43 |         random_resize_scale=(0.3, 1.0),
 44 |         reprob=0.0,
 45 |         auto_augment=False,
 46 |         motion_shift=False,
 47 |         crop_size=224,
 48 |         normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
 49 |     ):
 50 | 
 51 |         self.random_horizontal_flip = random_horizontal_flip
 52 |         self.random_resize_aspect_ratio = random_resize_aspect_ratio
 53 |         self.random_resize_scale = random_resize_scale
 54 |         self.auto_augment = auto_augment
 55 |         self.motion_shift = motion_shift
 56 |         self.crop_size = crop_size
 57 |         self.mean = torch.tensor(normalize[0], dtype=torch.float32)
 58 |         self.std = torch.tensor(normalize[1], dtype=torch.float32)
 59 |         if not self.auto_augment:
 60 |             # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
 61 |             self.mean *= 255.0
 62 |             self.std *= 255.0
 63 | 
 64 |         self.autoaug_transform = video_transforms.create_random_augment(
 65 |             input_size=(crop_size, crop_size),
 66 |             # auto_augment="rand-m4-n4-w1-mstd0.5-inc1",
 67 |             auto_augment="rand-m7-n4-mstd0.5-inc1",
 68 |             interpolation="bicubic",
 69 |         )
 70 | 
 71 |         self.spatial_transform = (
 72 |             video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
 73 |         )
 74 | 
 75 |         self.reprob = reprob
 76 |         self.erase_transform = RandomErasing(
 77 |             reprob,
 78 |             mode="pixel",
 79 |             max_count=1,
 80 |             num_splits=1,
 81 |             device="cpu",
 82 |         )
 83 | 
 84 |     def __call__(self, buffer):
 85 | 
 86 |         if self.auto_augment:
 87 |             buffer = [transforms.ToPILImage()(frame) for frame in buffer]
 88 |             buffer = self.autoaug_transform(buffer)
 89 |             buffer = [transforms.ToTensor()(img) for img in buffer]
 90 |             buffer = torch.stack(buffer)  # T C H W
 91 |             buffer = buffer.permute(0, 2, 3, 1)  # T H W C
 92 |         elif torch.is_tensor(buffer):
 93 |             # TODO: ensure input is always a tensor?
 94 |             buffer = buffer.to(torch.float32)
 95 |         else:
 96 |             buffer = torch.tensor(buffer, dtype=torch.float32)
 97 | 
 98 |         buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W
 99 | 
100 |         buffer = self.spatial_transform(
101 |             images=buffer,
102 |             target_height=self.crop_size,
103 |             target_width=self.crop_size,
104 |             scale=self.random_resize_scale,
105 |             ratio=self.random_resize_aspect_ratio,
106 |         )
107 |         if self.random_horizontal_flip:
108 |             buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
109 | 
110 |         buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
111 |         if self.reprob > 0:
112 |             buffer = buffer.permute(1, 0, 2, 3)
113 |             buffer = self.erase_transform(buffer)
114 |             buffer = buffer.permute(1, 0, 2, 3)
115 | 
116 |         return buffer
117 | 
118 | 
119 | def tensor_normalize(tensor, mean, std):
120 |     """
121 |     Normalize a given tensor by subtracting the mean and dividing the std.
122 |     Args:
123 |         tensor (tensor): tensor to normalize.
124 |         mean (tensor or list): mean value to subtract.
125 |         std (tensor or list): std to divide.
126 |     """
127 |     if tensor.dtype == torch.uint8:
128 |         tensor = tensor.float()
129 |         tensor = tensor / 255.0
130 |     if isinstance(mean, list):
131 |         mean = torch.tensor(mean)
132 |     if isinstance(std, list):
133 |         std = torch.tensor(std)
134 |     tensor = tensor - mean
135 |     tensor = tensor / std
136 |     return tensor
137 | 
138 | 
139 | def _tensor_normalize_inplace(tensor, mean, std):
140 |     """
141 |     Normalize a given tensor by subtracting the mean and dividing the std.
142 |     Args:
143 |         tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
144 |         mean (tensor): mean value to subtract (in 0 to 255 floats).
145 |         std (tensor): std to divide (in 0 to 255 floats).
146 |     """
147 |     if tensor.dtype == torch.uint8:
148 |         tensor = tensor.float()
149 | 
150 |     C, T, H, W = tensor.shape
151 |     tensor = tensor.view(C, -1).permute(1, 0)  # Make C the last dimension
152 |     tensor.sub_(mean).div_(std)
153 |     tensor = tensor.permute(1, 0).view(C, T, H, W)  # Put C back in front
154 |     return tensor
155 | 
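[Editor's note] A minimal usage sketch of the transform above, on synthetic data and the non-auto-augment path; the expected output layout assumes the `video_transforms` helpers behave as in this repo:

```python
# Illustrative only: feed a synthetic uint8 clip (T, H, W, C) through the
# transform built by make_transforms and inspect the resulting layout.
import torch
from app.vjepa.transforms import make_transforms

transform = make_transforms(crop_size=224, auto_augment=False, reprob=0.0)
clip = torch.randint(0, 256, (16, 256, 320, 3), dtype=torch.uint8)  # T, H, W, C
out = transform(clip)
# Expected: a normalized float tensor of shape (C, T, crop, crop), i.e. (3, 16, 224, 224)
print(out.shape, out.dtype)
```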


--------------------------------------------------------------------------------
/app/vjepa_droid/transforms.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Facebook, Inc. and its affiliates.
  2 | # All rights reserved.
  3 | #
  4 | # This source code is licensed under the license found in the
  5 | # LICENSE file in the root directory of this source tree.
  6 | #
  7 | 
  8 | import torch
  9 | import torchvision.transforms as transforms
 10 | 
 11 | import src.datasets.utils.video.transforms as video_transforms
 12 | from src.datasets.utils.video.randerase import RandomErasing
 13 | 
 14 | 
 15 | def make_transforms(
 16 |     random_horizontal_flip=True,
 17 |     random_resize_aspect_ratio=(3 / 4, 4 / 3),
 18 |     random_resize_scale=(0.3, 1.0),
 19 |     reprob=0.0,
 20 |     auto_augment=False,
 21 |     motion_shift=False,
 22 |     crop_size=224,
 23 |     normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
 24 | ):
 25 | 
 26 |     _frames_augmentation = VideoTransform(
 27 |         random_horizontal_flip=random_horizontal_flip,
 28 |         random_resize_aspect_ratio=random_resize_aspect_ratio,
 29 |         random_resize_scale=random_resize_scale,
 30 |         reprob=reprob,
 31 |         auto_augment=auto_augment,
 32 |         motion_shift=motion_shift,
 33 |         crop_size=crop_size,
 34 |         normalize=normalize,
 35 |     )
 36 |     return _frames_augmentation
 37 | 
 38 | 
 39 | class VideoTransform(object):
 40 | 
 41 |     def __init__(
 42 |         self,
 43 |         random_horizontal_flip=True,
 44 |         random_resize_aspect_ratio=(3 / 4, 4 / 3),
 45 |         random_resize_scale=(0.3, 1.0),
 46 |         reprob=0.0,
 47 |         auto_augment=False,
 48 |         motion_shift=False,
 49 |         crop_size=224,
 50 |         normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
 51 |     ):
 52 | 
 53 |         self.random_horizontal_flip = random_horizontal_flip
 54 |         self.random_resize_aspect_ratio = random_resize_aspect_ratio
 55 |         self.random_resize_scale = random_resize_scale
 56 |         self.auto_augment = auto_augment
 57 |         self.motion_shift = motion_shift
 58 |         self.crop_size = crop_size
 59 |         self.mean = torch.tensor(normalize[0], dtype=torch.float32)
 60 |         self.std = torch.tensor(normalize[1], dtype=torch.float32)
 61 |         if not self.auto_augment:
 62 |             # Without auto-augment, PIL and tensor conversions simply scale uint8 space by 255.
 63 |             self.mean *= 255.0
 64 |             self.std *= 255.0
 65 | 
 66 |         self.autoaug_transform = video_transforms.create_random_augment(
 67 |             input_size=(crop_size, crop_size),
 68 |             # auto_augment="rand-m4-n4-w1-mstd0.5-inc1",
 69 |             auto_augment="rand-m7-n4-mstd0.5-inc1",
 70 |             interpolation="bicubic",
 71 |         )
 72 | 
 73 |         self.spatial_transform = (
 74 |             video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
 75 |         )
 76 | 
 77 |         self.reprob = reprob
 78 |         self.erase_transform = RandomErasing(
 79 |             reprob,
 80 |             mode="pixel",
 81 |             max_count=1,
 82 |             num_splits=1,
 83 |             device="cpu",
 84 |         )
 85 | 
 86 |     def __call__(self, buffer):
 87 | 
 88 |         if self.auto_augment:
 89 |             buffer = [transforms.ToPILImage()(frame) for frame in buffer]
 90 |             buffer = self.autoaug_transform(buffer)
 91 |             buffer = [transforms.ToTensor()(img) for img in buffer]
 92 |             buffer = torch.stack(buffer)  # T C H W
 93 |             buffer = buffer.permute(0, 2, 3, 1)  # T H W C
 94 |         elif torch.is_tensor(buffer):
 95 |             # TODO: ensure input is always a tensor?
 96 |             buffer = buffer.to(torch.float32)
 97 |         else:
 98 |             buffer = torch.tensor(buffer, dtype=torch.float32)
 99 | 
100 |         buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W
101 | 
102 |         buffer = self.spatial_transform(
103 |             images=buffer,
104 |             target_height=self.crop_size,
105 |             target_width=self.crop_size,
106 |             scale=self.random_resize_scale,
107 |             ratio=self.random_resize_aspect_ratio,
108 |         )
109 |         if self.random_horizontal_flip:
110 |             buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
111 | 
112 |         buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
113 |         if self.reprob > 0:
114 |             buffer = buffer.permute(1, 0, 2, 3)
115 |             buffer = self.erase_transform(buffer)
116 |             buffer = buffer.permute(1, 0, 2, 3)
117 | 
118 |         return buffer
119 | 
120 | 
121 | def tensor_normalize(tensor, mean, std):
122 |     """
123 |     Normalize a given tensor by subtracting the mean and dividing the std.
124 |     Args:
125 |         tensor (tensor): tensor to normalize.
126 |         mean (tensor or list): mean value to subtract.
127 |         std (tensor or list): std to divide.
128 |     """
129 |     if tensor.dtype == torch.uint8:
130 |         tensor = tensor.float()
131 |         tensor = tensor / 255.0
132 |     if type(mean) == list:
133 |         mean = torch.tensor(mean)
134 |     if type(std) == list:
135 |         std = torch.tensor(std)
136 |     tensor = tensor - mean
137 |     tensor = tensor / std
138 |     return tensor
139 | 
140 | 
141 | def _tensor_normalize_inplace(tensor, mean, std):
142 |     """
143 |     Normalize a given tensor by subtracting the mean and dividing the std.
144 |     Args:
145 |         tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
146 |         mean (tensor): mean value to subtract (in 0 to 255 floats).
147 |         std (tensor): std to divide (in 0 to 255 floats).
148 |     """
149 |     if tensor.dtype == torch.uint8:
150 |         tensor = tensor.float()
151 | 
152 |     C, T, H, W = tensor.shape
153 |     tensor = tensor.view(C, -1).permute(1, 0)  # Make C the last dimension
154 |     tensor.sub_(mean).div_(std)
155 |     tensor = tensor.permute(1, 0).view(C, T, H, W)  # Put C back in front
156 |     return tensor
157 | 


--------------------------------------------------------------------------------
/assets/flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/assets/flowchart.png


--------------------------------------------------------------------------------
/assets/vjepa2-abstract-new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/assets/vjepa2-abstract-new.png


--------------------------------------------------------------------------------
/assets/vjepa2-ac-abstract-new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/assets/vjepa2-ac-abstract-new.png


--------------------------------------------------------------------------------
/configs/eval/vitg-384/coin.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: video_classification_frozen
  3 | folder: /your_folder/evals/vitg-384/coin
  4 | mem_per_gpu: 220G
  5 | nodes: 16
  6 | resume_checkpoint: true
  7 | tag: coin-vitg16-384-16x8x3
  8 | tasks_per_node: 8
  9 | experiment:
 10 |   classifier:
 11 |     num_heads: 16
 12 |     num_probe_blocks: 4
 13 |   data:
 14 |     dataset_type: VideoDataset
 15 |     dataset_train: /your_data_folder/COIN/train_paths.csv
 16 |     dataset_val: /your_data_folder/COIN/val_paths.csv
 17 |     frame_step: 4
 18 |     frames_per_clip: 16
 19 |     num_classes: 180
 20 |     num_segments: 8
 21 |     num_views_per_segment: 3
 22 |     resolution: 384
 23 |   optimization:
 24 |     batch_size: 1
 25 |     multihead_kwargs:
 26 |     - final_lr: 0.0
 27 |       final_weight_decay: 0.01
 28 |       lr: 0.005
 29 |       start_lr: 0.005
 30 |       warmup: 0.0
 31 |       weight_decay: 0.01
 32 |     - final_lr: 0.0
 33 |       final_weight_decay: 0.01
 34 |       lr: 0.003
 35 |       start_lr: 0.003
 36 |       warmup: 0.0
 37 |       weight_decay: 0.01
 38 |     - final_lr: 0.0
 39 |       final_weight_decay: 0.01
 40 |       lr: 0.001
 41 |       start_lr: 0.001
 42 |       warmup: 0.0
 43 |       weight_decay: 0.01
 44 |     - final_lr: 0.0
 45 |       final_weight_decay: 0.01
 46 |       lr: 0.0003
 47 |       start_lr: 0.0003
 48 |       warmup: 0.0
 49 |       weight_decay: 0.01
 50 |     - final_lr: 0.0
 51 |       final_weight_decay: 0.01
 52 |       lr: 0.0001
 53 |       start_lr: 0.0001
 54 |       warmup: 0.0
 55 |       weight_decay: 0.01
 56 |     - final_lr: 0.0
 57 |       final_weight_decay: 0.1
 58 |       lr: 0.005
 59 |       start_lr: 0.005
 60 |       warmup: 0.0
 61 |       weight_decay: 0.1
 62 |     - final_lr: 0.0
 63 |       final_weight_decay: 0.1
 64 |       lr: 0.003
 65 |       start_lr: 0.003
 66 |       warmup: 0.0
 67 |       weight_decay: 0.1
 68 |     - final_lr: 0.0
 69 |       final_weight_decay: 0.1
 70 |       lr: 0.001
 71 |       start_lr: 0.001
 72 |       warmup: 0.0
 73 |       weight_decay: 0.1
 74 |     - final_lr: 0.0
 75 |       final_weight_decay: 0.1
 76 |       lr: 0.0003
 77 |       start_lr: 0.0003
 78 |       warmup: 0.0
 79 |       weight_decay: 0.1
 80 |     - final_lr: 0.0
 81 |       final_weight_decay: 0.1
 82 |       lr: 0.0001
 83 |       start_lr: 0.0001
 84 |       warmup: 0.0
 85 |       weight_decay: 0.1
 86 |     - final_lr: 0.0
 87 |       final_weight_decay: 0.4
 88 |       lr: 0.005
 89 |       start_lr: 0.005
 90 |       warmup: 0.0
 91 |       weight_decay: 0.4
 92 |     - final_lr: 0.0
 93 |       final_weight_decay: 0.4
 94 |       lr: 0.003
 95 |       start_lr: 0.003
 96 |       warmup: 0.0
 97 |       weight_decay: 0.4
 98 |     - final_lr: 0.0
 99 |       final_weight_decay: 0.4
100 |       lr: 0.001
101 |       start_lr: 0.001
102 |       warmup: 0.0
103 |       weight_decay: 0.4
104 |     - final_lr: 0.0
105 |       final_weight_decay: 0.4
106 |       lr: 0.0003
107 |       start_lr: 0.0003
108 |       warmup: 0.0
109 |       weight_decay: 0.4
110 |     - final_lr: 0.0
111 |       final_weight_decay: 0.4
112 |       lr: 0.0001
113 |       start_lr: 0.0001
114 |       warmup: 0.0
115 |       weight_decay: 0.4
116 |     - final_lr: 0.0
117 |       final_weight_decay: 0.8
118 |       lr: 0.005
119 |       start_lr: 0.005
120 |       warmup: 0.0
121 |       weight_decay: 0.8
122 |     - final_lr: 0.0
123 |       final_weight_decay: 0.8
124 |       lr: 0.003
125 |       start_lr: 0.003
126 |       warmup: 0.0
127 |       weight_decay: 0.8
128 |     - final_lr: 0.0
129 |       final_weight_decay: 0.8
130 |       lr: 0.001
131 |       start_lr: 0.001
132 |       warmup: 0.0
133 |       weight_decay: 0.8
134 |     - final_lr: 0.0
135 |       final_weight_decay: 0.8
136 |       lr: 0.0003
137 |       start_lr: 0.0003
138 |       warmup: 0.0
139 |       weight_decay: 0.8
140 |     - final_lr: 0.0
141 |       final_weight_decay: 0.8
142 |       lr: 0.0001
143 |       start_lr: 0.0001
144 |       warmup: 0.0
145 |       weight_decay: 0.8
146 |     num_epochs: 20
147 |     use_bfloat16: true
148 |     use_pos_embed: false
149 | model_kwargs:
150 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
151 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
152 |   pretrain_kwargs:
153 |     encoder:
154 |       checkpoint_key: target_encoder
155 |       img_temporal_dim_size: null
156 |       model_name: vit_giant_xformers
157 |       patch_size: 16
158 |       tubelet_size: 2
159 |       uniform_power: true
160 |       use_rope: true
161 |   wrapper_kwargs:
162 |     max_frames: 128
163 |     use_pos_embed: false
164 | 
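[Editor's note] The 20 `multihead_kwargs` entries above are a grid sweep: every combination of five learning rates and four weight-decay values, each trained as a separate probe head. A sketch of an equivalent generator (illustrative, not code from the repo):

```python
# Reconstructs the multihead_kwargs list in coin.yaml as a 5 x 4 grid:
# lr in {5e-3, 3e-3, 1e-3, 3e-4, 1e-4} x weight_decay in {0.01, 0.1, 0.4, 0.8}.
lrs = [0.005, 0.003, 0.001, 0.0003, 0.0001]
wds = [0.01, 0.1, 0.4, 0.8]
multihead_kwargs = [
    {"lr": lr, "start_lr": lr, "final_lr": 0.0, "warmup": 0.0,
     "weight_decay": wd, "final_weight_decay": wd}
    for wd in wds      # outer loop over weight decay, matching the YAML ordering
    for lr in lrs
]
assert len(multihead_kwargs) == 20
```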


--------------------------------------------------------------------------------
/configs/eval/vitg-384/diving48.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | mem_per_gpu: 200G
 5 | tag: diving48-vitg16-384-32x4x3
 6 | eval_name: video_classification_frozen
 7 | folder: /your_folder/evals/vitg-384/diving48
 8 | resume_checkpoint: true
 9 | experiment:
10 |   classifier:
11 |     num_probe_blocks: 4
12 |     num_heads: 16
13 |   data:
14 |     dataset_type: VideoDataset
15 |     dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
16 |     dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
17 |     num_classes: 48
18 |     resolution: 384
19 |     frames_per_clip: 32
20 |     frame_step: 2
21 |     num_segments: 4
22 |     num_views_per_segment: 3
23 |   optimization:
24 |     use_pos_embed: false
25 |     num_epochs: 100
26 |     batch_size: 2
27 |     use_bfloat16: true
28 |     multihead_kwargs:
29 |       - weight_decay: 0.8
30 |         final_weight_decay: 0.8
31 |         lr: 0.001
32 |         start_lr: 0.001
33 |         final_lr: 0.0
34 |         warmup: 0.
35 |       - weight_decay: 0.8
36 |         final_weight_decay: 0.8
37 |         lr: 0.0003
38 |         start_lr: 0.0003
39 |         final_lr: 0.0
40 |         warmup: 0.
41 |       - weight_decay: 0.8
42 |         final_weight_decay: 0.8
43 |         lr: 0.0001
44 |         start_lr: 0.0001
45 |         final_lr: 0.0
46 |         warmup: 0.
47 | model_kwargs:
48 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
49 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
50 |   wrapper_kwargs:
51 |     max_frames: 128
52 |     use_pos_embed: false
53 |     out_layers: [24, 29, 34, 39]
54 |   pretrain_kwargs:
55 |     encoder:
56 |       model_name: vit_giant_xformers
57 |       checkpoint_key: target_encoder
58 |       tubelet_size: 2
59 |       patch_size: 16
60 |       uniform_power: true
61 |       use_rope: true
62 | 


--------------------------------------------------------------------------------
/configs/eval/vitg-384/ek100.yaml:
--------------------------------------------------------------------------------
  1 | nodes: 8
  2 | tasks_per_node: 8
  3 | cpus_per_task: 12
  4 | tag: ek100-vitg16-384
  5 | eval_name: action_anticipation_frozen
  6 | folder: /your_folder/evals/vitg-384/ek100
  7 | resume_checkpoint: true
  8 | experiment:
  9 |   classifier:
 10 |     num_probe_blocks: 4
 11 |     num_heads: 16
 12 |   data:
 13 |     anticipation_time_sec:
 14 |     - 1.0
 15 |     - 1.0
 16 |     auto_augment: true
 17 |     file_format: 0
 18 |     dataset: EK100
 19 |     # e.g. /home/username/EPIC-KITCHENS
 20 |     base_path: /your_ek100_root_dir/
 21 |     dataset_train: /your_data/EPIC_100_train.csv
 22 |     dataset_val: /your_data/EPIC_100_validation.csv
 23 |     frames_per_clip: 32
 24 |     frames_per_second: 8
 25 |     motion_shift: false
 26 |     num_workers: 2
 27 |     pin_memory: true
 28 |     random_resize_scale:
 29 |     - 0.08
 30 |     - 1.0
 31 |     reprob: 0.25
 32 |     resolution: 384
 33 |     train_anticipation_point:
 34 |     - 0.0
 35 |     - 0.25
 36 |     train_anticipation_time_sec:
 37 |     - 0.25
 38 |     - 1.75
 39 |   optimization:
 40 |     num_epochs: 20
 41 |     batch_size: 2
 42 |     use_bfloat16: true
 43 |     use_focal_loss: true
 44 |     multihead_kwargs:
 45 |       - weight_decay: 0.0001
 46 |         final_weight_decay: 0.0001
 47 |         lr: 0.005
 48 |         start_lr: 0.005
 49 |         final_lr: 0.0
 50 |         warmup: 0.
 51 |       - weight_decay: 0.0001
 52 |         final_weight_decay: 0.0001
 53 |         lr: 0.003
 54 |         start_lr: 0.003
 55 |         final_lr: 0.0
 56 |         warmup: 0.
 57 |       - weight_decay: 0.0001
 58 |         final_weight_decay: 0.0001
 59 |         lr: 0.001
 60 |         start_lr: 0.001
 61 |         final_lr: 0.0
 62 |         warmup: 0.
 63 |       - weight_decay: 0.0001
 64 |         final_weight_decay: 0.0001
 65 |         lr: 0.0003
 66 |         start_lr: 0.0003
 67 |         final_lr: 0.0
 68 |         warmup: 0.
 69 |       - weight_decay: 0.0001
 70 |         final_weight_decay: 0.0001
 71 |         lr: 0.0001
 72 |         start_lr: 0.0001
 73 |         final_lr: 0.0
 74 |         warmup: 0.
 75 | 
 76 |       - weight_decay: 0.001
 77 |         final_weight_decay: 0.001
 78 |         lr: 0.005
 79 |         start_lr: 0.005
 80 |         final_lr: 0.0
 81 |         warmup: 0.
 82 |       - weight_decay: 0.001
 83 |         final_weight_decay: 0.001
 84 |         lr: 0.003
 85 |         start_lr: 0.003
 86 |         final_lr: 0.0
 87 |         warmup: 0.
 88 |       - weight_decay: 0.001
 89 |         final_weight_decay: 0.001
 90 |         lr: 0.001
 91 |         start_lr: 0.001
 92 |         final_lr: 0.0
 93 |         warmup: 0.
 94 |       - weight_decay: 0.001
 95 |         final_weight_decay: 0.001
 96 |         lr: 0.0003
 97 |         start_lr: 0.0003
 98 |         final_lr: 0.0
 99 |         warmup: 0.
100 |       - weight_decay: 0.001
101 |         final_weight_decay: 0.001
102 |         lr: 0.0001
103 |         start_lr: 0.0001
104 |         final_lr: 0.0
105 |         warmup: 0.
106 | 
107 |       - weight_decay: 0.01
108 |         final_weight_decay: 0.01
109 |         lr: 0.005
110 |         start_lr: 0.005
111 |         final_lr: 0.0
112 |         warmup: 0.
113 |       - weight_decay: 0.01
114 |         final_weight_decay: 0.01
115 |         lr: 0.003
116 |         start_lr: 0.003
117 |         final_lr: 0.0
118 |         warmup: 0.
119 |       - weight_decay: 0.01
120 |         final_weight_decay: 0.01
121 |         lr: 0.001
122 |         start_lr: 0.001
123 |         final_lr: 0.0
124 |         warmup: 0.
125 |       - weight_decay: 0.01
126 |         final_weight_decay: 0.01
127 |         lr: 0.0003
128 |         start_lr: 0.0003
129 |         final_lr: 0.0
130 |         warmup: 0.
131 |       - weight_decay: 0.01
132 |         final_weight_decay: 0.01
133 |         lr: 0.0001
134 |         start_lr: 0.0001
135 |         final_lr: 0.0
136 |         warmup: 0.
137 | 
138 |       - weight_decay: 0.1
139 |         final_weight_decay: 0.1
140 |         lr: 0.005
141 |         start_lr: 0.005
142 |         final_lr: 0.0
143 |         warmup: 0.
144 |       - weight_decay: 0.1
145 |         final_weight_decay: 0.1
146 |         lr: 0.003
147 |         start_lr: 0.003
148 |         final_lr: 0.0
149 |         warmup: 0.
150 |       - weight_decay: 0.1
151 |         final_weight_decay: 0.1
152 |         lr: 0.001
153 |         start_lr: 0.001
154 |         final_lr: 0.0
155 |         warmup: 0.
156 |       - weight_decay: 0.1
157 |         final_weight_decay: 0.1
158 |         lr: 0.0003
159 |         start_lr: 0.0003
160 |         final_lr: 0.0
161 |         warmup: 0.
162 |       - weight_decay: 0.1
163 |         final_weight_decay: 0.1
164 |         lr: 0.0001
165 |         start_lr: 0.0001
166 |         final_lr: 0.0
167 |         warmup: 0.
168 | model_kwargs:
169 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
170 |   module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
171 |   wrapper_kwargs:
172 |     no_predictor: false
173 |     num_output_frames: 2
174 |     num_steps: 1
175 |   pretrain_kwargs:
176 |     encoder:
177 |       model_name: vit_giant_xformers
178 |       checkpoint_key: target_encoder
179 |       tubelet_size: 2
180 |       patch_size: 16
181 |       uniform_power: true
182 |       use_rope: true
183 |     predictor:
184 |       model_name: vit_predictor
185 |       checkpoint_key: predictor
186 |       num_frames: 64
187 |       depth: 12
188 |       num_heads: 12
189 |       predictor_embed_dim: 384
190 |       num_mask_tokens: 10
191 |       uniform_power: true
192 |       use_mask_tokens: true
193 |       use_sdpa: true
194 |       use_silu: false
195 |       wide_silu: false
196 |       use_rope: true
197 | 


--------------------------------------------------------------------------------
/configs/eval/vitg-384/in1k.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: image_classification_frozen
  3 | folder: /your_folder/evals/vitg-384/in1k
  4 | mem_per_gpu: 220G
  5 | nodes: 16
  6 | resume_checkpoint: true
  7 | tag: in1k-vitg16-384-18f
  8 | tasks_per_node: 8
  9 | experiment:
 10 |   classifier:
 11 |     num_heads: 16
 12 |     num_probe_blocks: 4
 13 |   data:
 14 |     dataset_name: ImageNet
 15 |     num_classes: 1000
 16 |     root_path: /datasets/
 17 |     image_folder: ImageNet_FullSize/240712/061417/
 18 |     resolution: 384
 19 |   optimization:
 20 |     batch_size: 8
 21 |     multihead_kwargs:
 22 |     - final_lr: 0.0
 23 |       final_weight_decay: 0.0
 24 |       lr: 0.001
 25 |       start_lr: 0.001
 26 |       warmup: 0.0
 27 |       weight_decay: 0.001
 28 |     - final_lr: 0.0
 29 |       final_weight_decay: 0.008
 30 |       lr: 0.0005
 31 |       start_lr: 0.0002
 32 |       warmup: 5
 33 |       weight_decay: 0.008
 34 |     - final_lr: 0.0
 35 |       final_weight_decay: 0.004
 36 |       lr: 0.0005
 37 |       start_lr: 0.0002
 38 |       warmup: 5
 39 |       weight_decay: 0.004
 40 |     - final_lr: 0.0
 41 |       final_weight_decay: 0.002
 42 |       lr: 0.0005
 43 |       start_lr: 0.0002
 44 |       warmup: 5
 45 |       weight_decay: 0.002
 46 |     - final_lr: 0.0
 47 |       final_weight_decay: 0.001
 48 |       lr: 0.0005
 49 |       start_lr: 0.0002
 50 |       warmup: 5
 51 |       weight_decay: 0.001
 52 |     - final_lr: 0.0
 53 |       final_weight_decay: 0.0005
 54 |       lr: 0.0005
 55 |       start_lr: 0.0002
 56 |       warmup: 5
 57 |       weight_decay: 0.0005
 58 |     - final_lr: 0.0
 59 |       final_weight_decay: 0.008
 60 |       lr: 0.001
 61 |       start_lr: 0.0002
 62 |       warmup: 5
 63 |       weight_decay: 0.008
 64 |     - final_lr: 0.0
 65 |       final_weight_decay: 0.004
 66 |       lr: 0.001
 67 |       start_lr: 0.0002
 68 |       warmup: 5
 69 |       weight_decay: 0.004
 70 |     - final_lr: 0.0
 71 |       final_weight_decay: 0.002
 72 |       lr: 0.001
 73 |       start_lr: 0.0002
 74 |       warmup: 5
 75 |       weight_decay: 0.002
 76 |     - final_lr: 0.0
 77 |       final_weight_decay: 0.001
 78 |       lr: 0.001
 79 |       start_lr: 0.0002
 80 |       warmup: 5
 81 |       weight_decay: 0.001
 82 |     - final_lr: 0.0
 83 |       final_weight_decay: 0.0005
 84 |       lr: 0.001
 85 |       start_lr: 0.0002
 86 |       warmup: 5
 87 |       weight_decay: 0.0005
 88 |     - final_lr: 0.0
 89 |       final_weight_decay: 0.008
 90 |       lr: 0.0015
 91 |       start_lr: 0.0002
 92 |       warmup: 5
 93 |       weight_decay: 0.008
 94 |     - final_lr: 0.0
 95 |       final_weight_decay: 0.004
 96 |       lr: 0.0015
 97 |       start_lr: 0.0002
 98 |       warmup: 5
 99 |       weight_decay: 0.004
100 |     - final_lr: 0.0
101 |       final_weight_decay: 0.002
102 |       lr: 0.0015
103 |       start_lr: 0.0002
104 |       warmup: 5
105 |       weight_decay: 0.002
106 |     - final_lr: 0.0
107 |       final_weight_decay: 0.001
108 |       lr: 0.0015
109 |       start_lr: 0.0002
110 |       warmup: 5
111 |       weight_decay: 0.001
112 |     - final_lr: 0.0
113 |       final_weight_decay: 0.0005
114 |       lr: 0.0015
115 |       start_lr: 0.0002
116 |       warmup: 5
117 |       weight_decay: 0.0005
118 |     - final_lr: 0.0
119 |       final_weight_decay: 0.008
120 |       lr: 0.002
121 |       start_lr: 0.0002
122 |       warmup: 5
123 |       weight_decay: 0.008
124 |     - final_lr: 0.0
125 |       final_weight_decay: 0.004
126 |       lr: 0.002
127 |       start_lr: 0.0002
128 |       warmup: 5
129 |       weight_decay: 0.004
130 |     - final_lr: 0.0
131 |       final_weight_decay: 0.002
132 |       lr: 0.002
133 |       start_lr: 0.0002
134 |       warmup: 5
135 |       weight_decay: 0.002
136 |     - final_lr: 0.0
137 |       final_weight_decay: 0.001
138 |       lr: 0.002
139 |       start_lr: 0.0002
140 |       warmup: 5
141 |       weight_decay: 0.001
142 |     - final_lr: 0.0
143 |       final_weight_decay: 0.0005
144 |       lr: 0.002
145 |       start_lr: 0.0002
146 |       warmup: 5
147 |       weight_decay: 0.0005
148 |     num_epochs: 20
149 |     use_bfloat16: true
150 | model_kwargs:
151 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
152 |   module_name: evals.image_classification_frozen.modelcustom.vit_encoder
153 |   pretrain_kwargs:
154 |     encoder:
155 |       checkpoint_key: target_encoder
156 |       img_temporal_dim_size: null
157 |       model_name: vit_giant_xformers
158 |       patch_size: 16
159 |       tubelet_size: 2
160 |       uniform_power: true
161 |       use_rope: true
162 |   wrapper_kwargs:
163 |     img_as_video_nframes: 18
164 | 


--------------------------------------------------------------------------------
/configs/eval/vitg-384/jester.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | mem_per_gpu: 200G
 5 | tag: jester-vitg16-384-32x4x3
 6 | eval_name: video_classification_frozen
 7 | folder: /your_folder/evals/vitg-384/jester
 8 | resume_checkpoint: true
 9 | experiment:
10 |   classifier:
11 |     num_probe_blocks: 4
12 |     num_heads: 16
13 |   data:
14 |     dataset_type: VideoDataset
15 |     dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv
16 |     dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv
17 |     num_classes: 27
18 |     resolution: 384
19 |     frames_per_clip: 32
20 |     frame_step: 2
21 |     num_segments: 4
22 |     num_views_per_segment: 3
23 |   optimization:
24 |     use_pos_embed: false
25 |     num_epochs: 100
26 |     batch_size: 2
27 |     use_bfloat16: true
28 |     multihead_kwargs:
29 |       - weight_decay: 0.8
30 |         final_weight_decay: 0.8
31 |         lr: 0.001
32 |         start_lr: 0.001
33 |         final_lr: 0.0
34 |         warmup: 0.
35 |       - weight_decay: 0.8
36 |         final_weight_decay: 0.8
37 |         lr: 0.0003
38 |         start_lr: 0.0003
39 |         final_lr: 0.0
40 |         warmup: 0.
41 |       - weight_decay: 0.8
42 |         final_weight_decay: 0.8
43 |         lr: 0.0001
44 |         start_lr: 0.0001
45 |         final_lr: 0.0
46 |         warmup: 0.
47 | model_kwargs:
48 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
49 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
50 |   wrapper_kwargs:
51 |     max_frames: 128
52 |     use_pos_embed: false
53 |     out_layers: [24, 29, 34, 39]
54 |   pretrain_kwargs:
55 |     encoder:
56 |       model_name: vit_giant_xformers
57 |       checkpoint_key: target_encoder
58 |       tubelet_size: 2
59 |       patch_size: 16
60 |       uniform_power: true
61 |       use_rope: true
62 | 


--------------------------------------------------------------------------------
/configs/eval/vitg-384/k400.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: video_classification_frozen
  3 | folder: /your_folder/evals/vitg-384/k400
  4 | mem_per_gpu: 220G
  5 | nodes: 32
  6 | num_workers: 8
  7 | resume_checkpoint: true
  8 | tag: k400-vitg16-384-16x8x3-16f
  9 | tasks_per_node: 8
 10 | experiment:
 11 |   classifier:
 12 |     num_heads: 16
 13 |     num_probe_blocks: 4
 14 |   data:
 15 |     dataset_type: VideoDataset
 16 |     dataset_train: /your_data_path/k400_train_paths.csv
 17 |     dataset_val: /your_data_path/k400_val_paths.csv
 18 |     frame_step: 4
 19 |     frames_per_clip: 16
 20 |     num_classes: 400
 21 |     num_segments: 8
 22 |     num_views_per_segment: 3
 23 |     resolution: 384
 24 |   optimization:
 25 |     batch_size: 1
 26 |     multihead_kwargs:
 27 |     - final_lr: 0.0
 28 |       final_weight_decay: 0.01
 29 |       lr: 0.005
 30 |       start_lr: 0.005
 31 |       warmup: 0.0
 32 |       weight_decay: 0.01
 33 |     - final_lr: 0.0
 34 |       final_weight_decay: 0.01
 35 |       lr: 0.003
 36 |       start_lr: 0.003
 37 |       warmup: 0.0
 38 |       weight_decay: 0.01
 39 |     - final_lr: 0.0
 40 |       final_weight_decay: 0.01
 41 |       lr: 0.001
 42 |       start_lr: 0.001
 43 |       warmup: 0.0
 44 |       weight_decay: 0.01
 45 |     - final_lr: 0.0
 46 |       final_weight_decay: 0.01
 47 |       lr: 0.0003
 48 |       start_lr: 0.0003
 49 |       warmup: 0.0
 50 |       weight_decay: 0.01
 51 |     - final_lr: 0.0
 52 |       final_weight_decay: 0.01
 53 |       lr: 0.0001
 54 |       start_lr: 0.0001
 55 |       warmup: 0.0
 56 |       weight_decay: 0.01
 57 |     - final_lr: 0.0
 58 |       final_weight_decay: 0.1
 59 |       lr: 0.005
 60 |       start_lr: 0.005
 61 |       warmup: 0.0
 62 |       weight_decay: 0.1
 63 |     - final_lr: 0.0
 64 |       final_weight_decay: 0.1
 65 |       lr: 0.003
 66 |       start_lr: 0.003
 67 |       warmup: 0.0
 68 |       weight_decay: 0.1
 69 |     - final_lr: 0.0
 70 |       final_weight_decay: 0.1
 71 |       lr: 0.001
 72 |       start_lr: 0.001
 73 |       warmup: 0.0
 74 |       weight_decay: 0.1
 75 |     - final_lr: 0.0
 76 |       final_weight_decay: 0.1
 77 |       lr: 0.0003
 78 |       start_lr: 0.0003
 79 |       warmup: 0.0
 80 |       weight_decay: 0.1
 81 |     - final_lr: 0.0
 82 |       final_weight_decay: 0.1
 83 |       lr: 0.0001
 84 |       start_lr: 0.0001
 85 |       warmup: 0.0
 86 |       weight_decay: 0.1
 87 |     - final_lr: 0.0
 88 |       final_weight_decay: 0.4
 89 |       lr: 0.005
 90 |       start_lr: 0.005
 91 |       warmup: 0.0
 92 |       weight_decay: 0.4
 93 |     - final_lr: 0.0
 94 |       final_weight_decay: 0.4
 95 |       lr: 0.003
 96 |       start_lr: 0.003
 97 |       warmup: 0.0
 98 |       weight_decay: 0.4
 99 |     - final_lr: 0.0
100 |       final_weight_decay: 0.4
101 |       lr: 0.001
102 |       start_lr: 0.001
103 |       warmup: 0.0
104 |       weight_decay: 0.4
105 |     - final_lr: 0.0
106 |       final_weight_decay: 0.4
107 |       lr: 0.0003
108 |       start_lr: 0.0003
109 |       warmup: 0.0
110 |       weight_decay: 0.4
111 |     - final_lr: 0.0
112 |       final_weight_decay: 0.4
113 |       lr: 0.0001
114 |       start_lr: 0.0001
115 |       warmup: 0.0
116 |       weight_decay: 0.4
117 |     - final_lr: 0.0
118 |       final_weight_decay: 0.8
119 |       lr: 0.005
120 |       start_lr: 0.005
121 |       warmup: 0.0
122 |       weight_decay: 0.8
123 |     - final_lr: 0.0
124 |       final_weight_decay: 0.8
125 |       lr: 0.003
126 |       start_lr: 0.003
127 |       warmup: 0.0
128 |       weight_decay: 0.8
129 |     - final_lr: 0.0
130 |       final_weight_decay: 0.8
131 |       lr: 0.001
132 |       start_lr: 0.001
133 |       warmup: 0.0
134 |       weight_decay: 0.8
135 |     - final_lr: 0.0
136 |       final_weight_decay: 0.8
137 |       lr: 0.0003
138 |       start_lr: 0.0003
139 |       warmup: 0.0
140 |       weight_decay: 0.8
141 |     - final_lr: 0.0
142 |       final_weight_decay: 0.8
143 |       lr: 0.0001
144 |       start_lr: 0.0001
145 |       warmup: 0.0
146 |       weight_decay: 0.8
147 |     num_epochs: 20
148 |     use_bfloat16: true
149 |     use_pos_embed: false
150 | model_kwargs:
151 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
152 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
153 |   pretrain_kwargs:
154 |     encoder:
155 |       checkpoint_key: target_encoder
156 |       img_temporal_dim_size: null
157 |       model_name: vit_giant_xformers
158 |       patch_size: 16
159 |       tubelet_size: 2
160 |       uniform_power: true
161 |       use_rope: true
162 |   wrapper_kwargs:
163 |     max_frames: 128
164 |     use_pos_embed: false
165 | 
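
Note on the optimization block above: multihead_kwargs is not a single schedule but a 20-entry hyperparameter sweep, crossing five learning rates (0.005, 0.003, 0.001, 0.0003, 0.0001) with four weight-decay values (0.01, 0.1, 0.4, 0.8); every entry keeps start_lr equal to lr, uses no warmup, and decays to final_lr 0.0, presumably training one attentive probe per entry on top of the frozen encoder. A minimal sketch (not part of the repo, requires PyYAML) that regenerates the same list programmatically:

    import itertools
    import yaml

    lrs = [0.005, 0.003, 0.001, 0.0003, 0.0001]
    wds = [0.01, 0.1, 0.4, 0.8]

    multihead_kwargs = [
        dict(lr=lr, start_lr=lr, final_lr=0.0,
             weight_decay=wd, final_weight_decay=wd, warmup=0.0)
        for wd, lr in itertools.product(wds, lrs)   # same ordering as the file above
    ]
    print(yaml.safe_dump({"multihead_kwargs": multihead_kwargs}))

Several of the eval configs below reuse this same 5 x 4 grid; only the swept values change per dataset.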


--------------------------------------------------------------------------------
/configs/eval/vitg-384/ssv2.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: video_classification_frozen
  3 | folder: /your_folder/evals/vitg-384/ssv2
  4 | mem_per_gpu: 220G
  5 | nodes: 16
  6 | num_workers: 8
  7 | resume_checkpoint: true
  8 | tag: ssv2-vitg16-384-64x2x3
  9 | tasks_per_node: 8
 10 | experiment:
 11 |   classifier:
 12 |     num_heads: 16
 13 |     num_probe_blocks: 4
 14 |   data:
 15 |     dataset_type: VideoDataset
 16 |     dataset_train: /your_data_path/ssv2_train_paths.csv
 17 |     dataset_val: /your_data_path/ssv2_val_paths.csv
 18 |     frame_step: 2
 19 |     frames_per_clip: 64
 20 |     num_classes: 174
 21 |     num_segments: 2
 22 |     num_views_per_segment: 3
 23 |     resolution: 384
 24 |   optimization:
 25 |     batch_size: 2
 26 |     multihead_kwargs:
 27 |     - final_lr: 0.0
 28 |       final_weight_decay: 0.01
 29 |       lr: 0.005
 30 |       start_lr: 0.005
 31 |       warmup: 0.0
 32 |       weight_decay: 0.01
 33 |     - final_lr: 0.0
 34 |       final_weight_decay: 0.01
 35 |       lr: 0.003
 36 |       start_lr: 0.003
 37 |       warmup: 0.0
 38 |       weight_decay: 0.01
 39 |     - final_lr: 0.0
 40 |       final_weight_decay: 0.01
 41 |       lr: 0.001
 42 |       start_lr: 0.001
 43 |       warmup: 0.0
 44 |       weight_decay: 0.01
 45 |     - final_lr: 0.0
 46 |       final_weight_decay: 0.01
 47 |       lr: 0.0003
 48 |       start_lr: 0.0003
 49 |       warmup: 0.0
 50 |       weight_decay: 0.01
 51 |     - final_lr: 0.0
 52 |       final_weight_decay: 0.01
 53 |       lr: 0.0001
 54 |       start_lr: 0.0001
 55 |       warmup: 0.0
 56 |       weight_decay: 0.01
 57 |     - final_lr: 0.0
 58 |       final_weight_decay: 0.1
 59 |       lr: 0.005
 60 |       start_lr: 0.005
 61 |       warmup: 0.0
 62 |       weight_decay: 0.1
 63 |     - final_lr: 0.0
 64 |       final_weight_decay: 0.1
 65 |       lr: 0.003
 66 |       start_lr: 0.003
 67 |       warmup: 0.0
 68 |       weight_decay: 0.1
 69 |     - final_lr: 0.0
 70 |       final_weight_decay: 0.1
 71 |       lr: 0.001
 72 |       start_lr: 0.001
 73 |       warmup: 0.0
 74 |       weight_decay: 0.1
 75 |     - final_lr: 0.0
 76 |       final_weight_decay: 0.1
 77 |       lr: 0.0003
 78 |       start_lr: 0.0003
 79 |       warmup: 0.0
 80 |       weight_decay: 0.1
 81 |     - final_lr: 0.0
 82 |       final_weight_decay: 0.1
 83 |       lr: 0.0001
 84 |       start_lr: 0.0001
 85 |       warmup: 0.0
 86 |       weight_decay: 0.1
 87 |     - final_lr: 0.0
 88 |       final_weight_decay: 0.4
 89 |       lr: 0.005
 90 |       start_lr: 0.005
 91 |       warmup: 0.0
 92 |       weight_decay: 0.4
 93 |     - final_lr: 0.0
 94 |       final_weight_decay: 0.4
 95 |       lr: 0.003
 96 |       start_lr: 0.003
 97 |       warmup: 0.0
 98 |       weight_decay: 0.4
 99 |     - final_lr: 0.0
100 |       final_weight_decay: 0.4
101 |       lr: 0.001
102 |       start_lr: 0.001
103 |       warmup: 0.0
104 |       weight_decay: 0.4
105 |     - final_lr: 0.0
106 |       final_weight_decay: 0.4
107 |       lr: 0.0003
108 |       start_lr: 0.0003
109 |       warmup: 0.0
110 |       weight_decay: 0.4
111 |     - final_lr: 0.0
112 |       final_weight_decay: 0.4
113 |       lr: 0.0001
114 |       start_lr: 0.0001
115 |       warmup: 0.0
116 |       weight_decay: 0.4
117 |     - final_lr: 0.0
118 |       final_weight_decay: 0.8
119 |       lr: 0.005
120 |       start_lr: 0.005
121 |       warmup: 0.0
122 |       weight_decay: 0.8
123 |     - final_lr: 0.0
124 |       final_weight_decay: 0.8
125 |       lr: 0.003
126 |       start_lr: 0.003
127 |       warmup: 0.0
128 |       weight_decay: 0.8
129 |     - final_lr: 0.0
130 |       final_weight_decay: 0.8
131 |       lr: 0.001
132 |       start_lr: 0.001
133 |       warmup: 0.0
134 |       weight_decay: 0.8
135 |     - final_lr: 0.0
136 |       final_weight_decay: 0.8
137 |       lr: 0.0003
138 |       start_lr: 0.0003
139 |       warmup: 0.0
140 |       weight_decay: 0.8
141 |     - final_lr: 0.0
142 |       final_weight_decay: 0.8
143 |       lr: 0.0001
144 |       start_lr: 0.0001
145 |       warmup: 0.0
146 |       weight_decay: 0.8
147 |     num_epochs: 20
148 |     use_bfloat16: true
149 |     use_pos_embed: false
150 | model_kwargs:
151 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
152 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
153 |   pretrain_kwargs:
154 |     encoder:
155 |       checkpoint_key: target_encoder
156 |       img_temporal_dim_size: null
157 |       model_name: vit_giant_xformers
158 |       patch_size: 16
159 |       tubelet_size: 2
160 |       uniform_power: true
161 |       use_rope: true
162 |   wrapper_kwargs:
163 |     max_frames: 128
164 |     use_pos_embed: false
165 | 


--------------------------------------------------------------------------------
/configs/eval/vitl/coin.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: video_classification_frozen
  3 | folder: /your_folder/evals/vitl/coin
  4 | mem_per_gpu: 220G
  5 | nodes: 8
  6 | resume_checkpoint: true
  7 | tag: coin-vitl16-384-16x8x3
  8 | tasks_per_node: 8
  9 | experiment:
 10 |   classifier:
 11 |     num_heads: 16
 12 |     num_probe_blocks: 4
 13 |   data:
 14 |     dataset_type: VideoDataset
 15 |     dataset_train: /your_data_folder/COIN/train_paths.csv
 16 |     dataset_val: /your_data_folder/COIN/val_paths.csv
 17 |     frame_step: 4
 18 |     frames_per_clip: 16
 19 |     num_classes: 180
 20 |     num_segments: 8
 21 |     num_views_per_segment: 3
 22 |     resolution: 256
 23 |   optimization:
 24 |     batch_size: 2
 25 |     multihead_kwargs:
 26 |     - final_lr: 0.0
 27 |       final_weight_decay: 0.01
 28 |       lr: 0.005
 29 |       start_lr: 0.005
 30 |       warmup: 0.0
 31 |       weight_decay: 0.01
 32 |     - final_lr: 0.0
 33 |       final_weight_decay: 0.01
 34 |       lr: 0.003
 35 |       start_lr: 0.003
 36 |       warmup: 0.0
 37 |       weight_decay: 0.01
 38 |     - final_lr: 0.0
 39 |       final_weight_decay: 0.01
 40 |       lr: 0.001
 41 |       start_lr: 0.001
 42 |       warmup: 0.0
 43 |       weight_decay: 0.01
 44 |     - final_lr: 0.0
 45 |       final_weight_decay: 0.01
 46 |       lr: 0.0003
 47 |       start_lr: 0.0003
 48 |       warmup: 0.0
 49 |       weight_decay: 0.01
 50 |     - final_lr: 0.0
 51 |       final_weight_decay: 0.01
 52 |       lr: 0.0001
 53 |       start_lr: 0.0001
 54 |       warmup: 0.0
 55 |       weight_decay: 0.01
 56 |     - final_lr: 0.0
 57 |       final_weight_decay: 0.1
 58 |       lr: 0.005
 59 |       start_lr: 0.005
 60 |       warmup: 0.0
 61 |       weight_decay: 0.1
 62 |     - final_lr: 0.0
 63 |       final_weight_decay: 0.1
 64 |       lr: 0.003
 65 |       start_lr: 0.003
 66 |       warmup: 0.0
 67 |       weight_decay: 0.1
 68 |     - final_lr: 0.0
 69 |       final_weight_decay: 0.1
 70 |       lr: 0.001
 71 |       start_lr: 0.001
 72 |       warmup: 0.0
 73 |       weight_decay: 0.1
 74 |     - final_lr: 0.0
 75 |       final_weight_decay: 0.1
 76 |       lr: 0.0003
 77 |       start_lr: 0.0003
 78 |       warmup: 0.0
 79 |       weight_decay: 0.1
 80 |     - final_lr: 0.0
 81 |       final_weight_decay: 0.1
 82 |       lr: 0.0001
 83 |       start_lr: 0.0001
 84 |       warmup: 0.0
 85 |       weight_decay: 0.1
 86 |     - final_lr: 0.0
 87 |       final_weight_decay: 0.4
 88 |       lr: 0.005
 89 |       start_lr: 0.005
 90 |       warmup: 0.0
 91 |       weight_decay: 0.4
 92 |     - final_lr: 0.0
 93 |       final_weight_decay: 0.4
 94 |       lr: 0.003
 95 |       start_lr: 0.003
 96 |       warmup: 0.0
 97 |       weight_decay: 0.4
 98 |     - final_lr: 0.0
 99 |       final_weight_decay: 0.4
100 |       lr: 0.001
101 |       start_lr: 0.001
102 |       warmup: 0.0
103 |       weight_decay: 0.4
104 |     - final_lr: 0.0
105 |       final_weight_decay: 0.4
106 |       lr: 0.0003
107 |       start_lr: 0.0003
108 |       warmup: 0.0
109 |       weight_decay: 0.4
110 |     - final_lr: 0.0
111 |       final_weight_decay: 0.4
112 |       lr: 0.0001
113 |       start_lr: 0.0001
114 |       warmup: 0.0
115 |       weight_decay: 0.4
116 |     - final_lr: 0.0
117 |       final_weight_decay: 0.8
118 |       lr: 0.005
119 |       start_lr: 0.005
120 |       warmup: 0.0
121 |       weight_decay: 0.8
122 |     - final_lr: 0.0
123 |       final_weight_decay: 0.8
124 |       lr: 0.003
125 |       start_lr: 0.003
126 |       warmup: 0.0
127 |       weight_decay: 0.8
128 |     - final_lr: 0.0
129 |       final_weight_decay: 0.8
130 |       lr: 0.001
131 |       start_lr: 0.001
132 |       warmup: 0.0
133 |       weight_decay: 0.8
134 |     - final_lr: 0.0
135 |       final_weight_decay: 0.8
136 |       lr: 0.0003
137 |       start_lr: 0.0003
138 |       warmup: 0.0
139 |       weight_decay: 0.8
140 |     - final_lr: 0.0
141 |       final_weight_decay: 0.8
142 |       lr: 0.0001
143 |       start_lr: 0.0001
144 |       warmup: 0.0
145 |       weight_decay: 0.8
146 |     num_epochs: 20
147 |     use_bfloat16: true
148 |     use_pos_embed: false
149 | model_kwargs:
150 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
151 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
152 |   pretrain_kwargs:
153 |     encoder:
154 |       checkpoint_key: target_encoder
155 |       img_temporal_dim_size: null
156 |       model_name: vit_large
157 |       patch_size: 16
158 |       tubelet_size: 2
159 |       uniform_power: true
160 |       use_rope: true
161 |   wrapper_kwargs:
162 |     max_frames: 128
163 |     use_pos_embed: false
164 | 


--------------------------------------------------------------------------------
/configs/eval/vitl/diving48.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | mem_per_gpu: 200G
 5 | tag: diving48-vitl16-32x4x3
 6 | eval_name: video_classification_frozen
 7 | folder: /your_folder/evals/vitl/diving48
 8 | resume_checkpoint: true
 9 | experiment:
10 |   classifier:
11 |     num_probe_blocks: 4
12 |     num_heads: 16
13 |   data:
14 |     dataset_type: VideoDataset
15 |     dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
16 |     dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
17 |     num_classes: 48
18 |     resolution: 256
19 |     frames_per_clip: 32
20 |     frame_step: 2
21 |     num_segments: 4
22 |     num_views_per_segment: 3
23 |   optimization:
24 |     use_pos_embed: false
25 |     num_epochs: 100
26 |     batch_size: 2
27 |     use_bfloat16: true
28 |     multihead_kwargs:
29 |       - weight_decay: 0.8
30 |         final_weight_decay: 0.8
31 |         lr: 0.001
32 |         start_lr: 0.001
33 |         final_lr: 0.0
34 |         warmup: 0.
35 |       - weight_decay: 0.8
36 |         final_weight_decay: 0.8
37 |         lr: 0.0003
38 |         start_lr: 0.0003
39 |         final_lr: 0.0
40 |         warmup: 0.
41 |       - weight_decay: 0.8
42 |         final_weight_decay: 0.8
43 |         lr: 0.0001
44 |         start_lr: 0.0001
45 |         final_lr: 0.0
46 |         warmup: 0.
47 | model_kwargs:
48 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
49 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
50 |   wrapper_kwargs:
51 |     max_frames: 128
52 |     use_pos_embed: false
53 |     out_layers: [17, 19, 21, 23]
54 |   pretrain_kwargs:
55 |     encoder:
56 |       model_name: vit_large
57 |       checkpoint_key: target_encoder
58 |       tubelet_size: 2
59 |       patch_size: 16
60 |       uniform_power: true
61 |       use_rope: true
62 | 
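
Unlike the k400/ssv2 probes, this config (and the jester one below) routes the probe through the multilevel wrapper and taps features from four late encoder blocks via out_layers. A small indexing sketch, assuming out_layers are 0-indexed block indices of a 24-block ViT-L (the block count is an assumption, not stated in the config); the vitg-384 inference config later in this dump uses [24, 29, 34, 39] for the deeper giant encoder:

    # Four taps spaced two blocks apart, ending at the final block (0-indexed).
    depth, taps, stride = 24, 4, 2
    print([depth - 1 - stride * i for i in reversed(range(taps))])  # [17, 19, 21, 23]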


--------------------------------------------------------------------------------
/configs/eval/vitl/ek100.yaml:
--------------------------------------------------------------------------------
  1 | nodes: 8
  2 | tasks_per_node: 8
  3 | cpus_per_task: 12
  4 | tag: ek100-vitl16
  5 | eval_name: action_anticipation_frozen
  6 | folder: /your_folder/evals/vitl/ek100
  7 | resume_checkpoint: true
  8 | experiment:
  9 |   classifier:
 10 |     num_probe_blocks: 4
 11 |     num_heads: 16
 12 |   data:
 13 |     anticipation_time_sec:
 14 |     - 1.0
 15 |     - 1.0
 16 |     auto_augment: true
 17 |     file_format: 0
 18 |     dataset: EK100
 19 |     # e.g. /home/username/EPIC-KITCHENS
 20 |     base_path: /your_ek100_root_dir/
 21 |     dataset_train: /your_data/EPIC_100_train.csv
 22 |     dataset_val: /your_data/EPIC_100_validation.csv
 23 |     frames_per_clip: 32
 24 |     frames_per_second: 8
 25 |     motion_shift: false
 26 |     num_workers: 2
 27 |     pin_memory: true
 28 |     random_resize_scale:
 29 |     - 0.08
 30 |     - 1.0
 31 |     reprob: 0.25
 32 |     resolution: 256
 33 |     train_anticipation_point:
 34 |     - 0.0
 35 |     - 0.25
 36 |     train_anticipation_time_sec:
 37 |     - 0.25
 38 |     - 1.75
 39 |   optimization:
 40 |     num_epochs: 20
 41 |     batch_size: 2
 42 |     use_bfloat16: true
 43 |     use_focal_loss: true
 44 |     multihead_kwargs:
 45 |       - weight_decay: 0.0001
 46 |         final_weight_decay: 0.0001
 47 |         lr: 0.005
 48 |         start_lr: 0.005
 49 |         final_lr: 0.0
 50 |         warmup: 0.
 51 |       - weight_decay: 0.0001
 52 |         final_weight_decay: 0.0001
 53 |         lr: 0.003
 54 |         start_lr: 0.003
 55 |         final_lr: 0.0
 56 |         warmup: 0.
 57 |       - weight_decay: 0.0001
 58 |         final_weight_decay: 0.0001
 59 |         lr: 0.001
 60 |         start_lr: 0.001
 61 |         final_lr: 0.0
 62 |         warmup: 0.
 63 |       - weight_decay: 0.0001
 64 |         final_weight_decay: 0.0001
 65 |         lr: 0.0003
 66 |         start_lr: 0.0003
 67 |         final_lr: 0.0
 68 |         warmup: 0.
 69 |       - weight_decay: 0.0001
 70 |         final_weight_decay: 0.0001
 71 |         lr: 0.0001
 72 |         start_lr: 0.0001
 73 |         final_lr: 0.0
 74 |         warmup: 0.
 75 | 
 76 |       - weight_decay: 0.001
 77 |         final_weight_decay: 0.001
 78 |         lr: 0.005
 79 |         start_lr: 0.005
 80 |         final_lr: 0.0
 81 |         warmup: 0.
 82 |       - weight_decay: 0.001
 83 |         final_weight_decay: 0.001
 84 |         lr: 0.003
 85 |         start_lr: 0.003
 86 |         final_lr: 0.0
 87 |         warmup: 0.
 88 |       - weight_decay: 0.001
 89 |         final_weight_decay: 0.001
 90 |         lr: 0.001
 91 |         start_lr: 0.001
 92 |         final_lr: 0.0
 93 |         warmup: 0.
 94 |       - weight_decay: 0.001
 95 |         final_weight_decay: 0.001
 96 |         lr: 0.0003
 97 |         start_lr: 0.0003
 98 |         final_lr: 0.0
 99 |         warmup: 0.
100 |       - weight_decay: 0.001
101 |         final_weight_decay: 0.001
102 |         lr: 0.0001
103 |         start_lr: 0.0001
104 |         final_lr: 0.0
105 |         warmup: 0.
106 | 
107 |       - weight_decay: 0.01
108 |         final_weight_decay: 0.01
109 |         lr: 0.005
110 |         start_lr: 0.005
111 |         final_lr: 0.0
112 |         warmup: 0.
113 |       - weight_decay: 0.01
114 |         final_weight_decay: 0.01
115 |         lr: 0.003
116 |         start_lr: 0.003
117 |         final_lr: 0.0
118 |         warmup: 0.
119 |       - weight_decay: 0.01
120 |         final_weight_decay: 0.01
121 |         lr: 0.001
122 |         start_lr: 0.001
123 |         final_lr: 0.0
124 |         warmup: 0.
125 |       - weight_decay: 0.01
126 |         final_weight_decay: 0.01
127 |         lr: 0.0003
128 |         start_lr: 0.0003
129 |         final_lr: 0.0
130 |         warmup: 0.
131 |       - weight_decay: 0.01
132 |         final_weight_decay: 0.01
133 |         lr: 0.0001
134 |         start_lr: 0.0001
135 |         final_lr: 0.0
136 |         warmup: 0.
137 | 
138 |       - weight_decay: 0.1
139 |         final_weight_decay: 0.1
140 |         lr: 0.005
141 |         start_lr: 0.005
142 |         final_lr: 0.0
143 |         warmup: 0.
144 |       - weight_decay: 0.1
145 |         final_weight_decay: 0.1
146 |         lr: 0.003
147 |         start_lr: 0.003
148 |         final_lr: 0.0
149 |         warmup: 0.
150 |       - weight_decay: 0.1
151 |         final_weight_decay: 0.1
152 |         lr: 0.001
153 |         start_lr: 0.001
154 |         final_lr: 0.0
155 |         warmup: 0.
156 |       - weight_decay: 0.1
157 |         final_weight_decay: 0.1
158 |         lr: 0.0003
159 |         start_lr: 0.0003
160 |         final_lr: 0.0
161 |         warmup: 0.
162 |       - weight_decay: 0.1
163 |         final_weight_decay: 0.1
164 |         lr: 0.0001
165 |         start_lr: 0.0001
166 |         final_lr: 0.0
167 |         warmup: 0.
168 | model_kwargs:
169 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
170 |   module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
171 |   wrapper_kwargs:
172 |     no_predictor: false
173 |     num_output_frames: 2
174 |     num_steps: 1
175 |   pretrain_kwargs:
176 |     encoder:
177 |       model_name: vit_large
178 |       checkpoint_key: target_encoder
179 |       tubelet_size: 2
180 |       patch_size: 16
181 |       uniform_power: true
182 |       use_rope: true
183 |     predictor:
184 |       model_name: vit_predictor
185 |       checkpoint_key: predictor
186 |       num_frames: 64
187 |       depth: 12
188 |       num_heads: 12
189 |       predictor_embed_dim: 384
190 |       num_mask_tokens: 10
191 |       uniform_power: true
192 |       use_mask_tokens: true
193 |       use_sdpa: true
194 |       use_silu: false
195 |       wide_silu: false
196 |       use_rope: true
197 | 
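
Reading the data block above: 32 frames at 8 fps give a 4-second observed context, anticipation_time_sec fixes a 1.0 s gap between context and action for validation, and training samples the gap from [0.25, 1.75] s and the anticipation point from [0.0, 0.25] (an interpretation of these keys, not something the config states; the exact semantics live in the EK100 dataloader). Quick arithmetic, for reference:

    frames_per_clip, fps = 32, 8
    print(frames_per_clip / fps)   # 4.0 s of observed video per clip
    print((0.25 + 1.75) / 2)       # 1.0 s mean training gap, matching the eval gap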


--------------------------------------------------------------------------------
/configs/eval/vitl/in1k.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: image_classification_frozen
  3 | folder: /your_folder/evals/vitl/in1k
  4 | mem_per_gpu: 220G
  5 | nodes: 8
  6 | resume_checkpoint: true
  7 | tag: in1k-vitl16-384-18f
  8 | tasks_per_node: 8
  9 | experiment:
 10 |   classifier:
 11 |     num_heads: 16
 12 |     num_probe_blocks: 4
 13 |   data:
 14 |     dataset_name: ImageNet
 15 |     num_classes: 1000
 16 |     root_path: /datasets/
 17 |     image_folder: ImageNet_FullSize/240712/061417/
 18 |     resolution: 256
 19 |   optimization:
 20 |     batch_size: 16
 21 |     multihead_kwargs:
 22 |     - final_lr: 0.0
 23 |       final_weight_decay: 0.0
 24 |       lr: 0.001
 25 |       start_lr: 0.001
 26 |       warmup: 0.0
 27 |       weight_decay: 0.001
 28 |     - final_lr: 0.0
 29 |       final_weight_decay: 0.008
 30 |       lr: 0.0005
 31 |       start_lr: 0.0002
 32 |       warmup: 5
 33 |       weight_decay: 0.008
 34 |     - final_lr: 0.0
 35 |       final_weight_decay: 0.004
 36 |       lr: 0.0005
 37 |       start_lr: 0.0002
 38 |       warmup: 5
 39 |       weight_decay: 0.004
 40 |     - final_lr: 0.0
 41 |       final_weight_decay: 0.002
 42 |       lr: 0.0005
 43 |       start_lr: 0.0002
 44 |       warmup: 5
 45 |       weight_decay: 0.002
 46 |     - final_lr: 0.0
 47 |       final_weight_decay: 0.001
 48 |       lr: 0.0005
 49 |       start_lr: 0.0002
 50 |       warmup: 5
 51 |       weight_decay: 0.001
 52 |     - final_lr: 0.0
 53 |       final_weight_decay: 0.0005
 54 |       lr: 0.0005
 55 |       start_lr: 0.0002
 56 |       warmup: 5
 57 |       weight_decay: 0.0005
 58 |     - final_lr: 0.0
 59 |       final_weight_decay: 0.008
 60 |       lr: 0.001
 61 |       start_lr: 0.0002
 62 |       warmup: 5
 63 |       weight_decay: 0.008
 64 |     - final_lr: 0.0
 65 |       final_weight_decay: 0.004
 66 |       lr: 0.001
 67 |       start_lr: 0.0002
 68 |       warmup: 5
 69 |       weight_decay: 0.004
 70 |     - final_lr: 0.0
 71 |       final_weight_decay: 0.002
 72 |       lr: 0.001
 73 |       start_lr: 0.0002
 74 |       warmup: 5
 75 |       weight_decay: 0.002
 76 |     - final_lr: 0.0
 77 |       final_weight_decay: 0.001
 78 |       lr: 0.001
 79 |       start_lr: 0.0002
 80 |       warmup: 5
 81 |       weight_decay: 0.001
 82 |     - final_lr: 0.0
 83 |       final_weight_decay: 0.0005
 84 |       lr: 0.001
 85 |       start_lr: 0.0002
 86 |       warmup: 5
 87 |       weight_decay: 0.0005
 88 |     - final_lr: 0.0
 89 |       final_weight_decay: 0.008
 90 |       lr: 0.0015
 91 |       start_lr: 0.0002
 92 |       warmup: 5
 93 |       weight_decay: 0.008
 94 |     - final_lr: 0.0
 95 |       final_weight_decay: 0.004
 96 |       lr: 0.0015
 97 |       start_lr: 0.0002
 98 |       warmup: 5
 99 |       weight_decay: 0.004
100 |     - final_lr: 0.0
101 |       final_weight_decay: 0.002
102 |       lr: 0.0015
103 |       start_lr: 0.0002
104 |       warmup: 5
105 |       weight_decay: 0.002
106 |     - final_lr: 0.0
107 |       final_weight_decay: 0.001
108 |       lr: 0.0015
109 |       start_lr: 0.0002
110 |       warmup: 5
111 |       weight_decay: 0.001
112 |     - final_lr: 0.0
113 |       final_weight_decay: 0.0005
114 |       lr: 0.0015
115 |       start_lr: 0.0002
116 |       warmup: 5
117 |       weight_decay: 0.0005
118 |     - final_lr: 0.0
119 |       final_weight_decay: 0.008
120 |       lr: 0.002
121 |       start_lr: 0.0002
122 |       warmup: 5
123 |       weight_decay: 0.008
124 |     - final_lr: 0.0
125 |       final_weight_decay: 0.004
126 |       lr: 0.002
127 |       start_lr: 0.0002
128 |       warmup: 5
129 |       weight_decay: 0.004
130 |     - final_lr: 0.0
131 |       final_weight_decay: 0.002
132 |       lr: 0.002
133 |       start_lr: 0.0002
134 |       warmup: 5
135 |       weight_decay: 0.002
136 |     - final_lr: 0.0
137 |       final_weight_decay: 0.001
138 |       lr: 0.002
139 |       start_lr: 0.0002
140 |       warmup: 5
141 |       weight_decay: 0.001
142 |     - final_lr: 0.0
143 |       final_weight_decay: 0.0005
144 |       lr: 0.002
145 |       start_lr: 0.0002
146 |       warmup: 5
147 |       weight_decay: 0.0005
148 |     num_epochs: 20
149 |     use_bfloat16: true
150 | model_kwargs:
151 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
152 |   module_name: evals.image_classification_frozen.modelcustom.vit_encoder
153 |   pretrain_kwargs:
154 |     encoder:
155 |       checkpoint_key: target_encoder
156 |       img_temporal_dim_size: null
157 |       model_name: vit_large
158 |       patch_size: 16
159 |       tubelet_size: 2
160 |       uniform_power: true
161 |       use_rope: true
162 |   wrapper_kwargs:
163 |     img_as_video_nframes: 16
164 | 
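
The in1k probe feeds single images to the video encoder, and wrapper_kwargs.img_as_video_nframes: 16 presumably controls how each image is expanded into a clip before patchification. A hypothetical sketch of that expansion (tiling along the time axis; the actual wrapper may differ):

    import torch

    img = torch.randn(1, 3, 256, 256)               # B, C, H, W
    clip = img.unsqueeze(2).repeat(1, 1, 16, 1, 1)  # B, C, T, H, W with T = img_as_video_nframes
    print(clip.shape)                               # torch.Size([1, 3, 16, 256, 256])

Note also that, unlike the video sweeps, this grid uses a 5-epoch warmup from start_lr 0.0002 for all but the first entry.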


--------------------------------------------------------------------------------
/configs/eval/vitl/jester.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | mem_per_gpu: 200G
 5 | tag: jester-vitl16-32x4x3
 6 | eval_name: video_classification_frozen
 7 | folder: /your_folder/evals/vitl/jester
 8 | resume_checkpoint: true
 9 | experiment:
10 |   classifier:
11 |     num_probe_blocks: 4
12 |     num_heads: 16
13 |   data:
14 |     dataset_type: VideoDataset
15 |     dataset_train: /your_data_dir/Jester/annotations/jester_train_paths.csv
16 |     dataset_val: /your_data_dir/Jester/annotations/jester_validation_paths.csv
17 |     num_classes: 27
18 |     resolution: 256
19 |     frames_per_clip: 32
20 |     frame_step: 2
21 |     num_segments: 4
22 |     num_views_per_segment: 3
23 |   optimization:
24 |     use_pos_embed: false
25 |     num_epochs: 100
26 |     batch_size: 2
27 |     use_bfloat16: true
28 |     multihead_kwargs:
29 |       - weight_decay: 0.8
30 |         final_weight_decay: 0.8
31 |         lr: 0.001
32 |         start_lr: 0.001
33 |         final_lr: 0.0
34 |         warmup: 0.
35 |       - weight_decay: 0.8
36 |         final_weight_decay: 0.8
37 |         lr: 0.0003
38 |         start_lr: 0.0003
39 |         final_lr: 0.0
40 |         warmup: 0.
41 |       - weight_decay: 0.8
42 |         final_weight_decay: 0.8
43 |         lr: 0.0001
44 |         start_lr: 0.0001
45 |         final_lr: 0.0
46 |         warmup: 0.
47 | model_kwargs:
48 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
49 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
50 |   wrapper_kwargs:
51 |     max_frames: 128
52 |     use_pos_embed: false
53 |     out_layers: [17, 19, 21, 23]
54 |   pretrain_kwargs:
55 |     encoder:
56 |       model_name: vit_large
57 |       checkpoint_key: target_encoder
58 |       tubelet_size: 2
59 |       patch_size: 16
60 |       uniform_power: true
61 |       use_rope: true
62 | 


--------------------------------------------------------------------------------
/configs/eval/vitl/k400.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: video_classification_frozen
  3 | folder: /your_folder/evals/vitl/k400
  4 | mem_per_gpu: 220G
  5 | nodes: 8
  6 | num_workers: 8
  7 | resume_checkpoint: true
  8 | tag: k400-vitl16-16x8x3-16f
  9 | tasks_per_node: 8
 10 | experiment:
 11 |   classifier:
 12 |     num_heads: 16
 13 |     num_probe_blocks: 4
 14 |   data:
 15 |     dataset_type: VideoDataset
 16 |     dataset_train: /your_data_path/k400_train_paths.csv
 17 |     dataset_val: /your_data_path/k400_val_paths.csv
 18 |     frame_step: 4
 19 |     frames_per_clip: 16
 20 |     num_classes: 400
 21 |     num_segments: 8
 22 |     num_views_per_segment: 3
 23 |     resolution: 256
 24 |   optimization:
 25 |     batch_size: 4
 26 |     multihead_kwargs:
 27 |     - final_lr: 0.0
 28 |       final_weight_decay: 0.01
 29 |       lr: 0.005
 30 |       start_lr: 0.005
 31 |       warmup: 0.0
 32 |       weight_decay: 0.01
 33 |     - final_lr: 0.0
 34 |       final_weight_decay: 0.01
 35 |       lr: 0.003
 36 |       start_lr: 0.003
 37 |       warmup: 0.0
 38 |       weight_decay: 0.01
 39 |     - final_lr: 0.0
 40 |       final_weight_decay: 0.01
 41 |       lr: 0.001
 42 |       start_lr: 0.001
 43 |       warmup: 0.0
 44 |       weight_decay: 0.01
 45 |     - final_lr: 0.0
 46 |       final_weight_decay: 0.01
 47 |       lr: 0.0003
 48 |       start_lr: 0.0003
 49 |       warmup: 0.0
 50 |       weight_decay: 0.01
 51 |     - final_lr: 0.0
 52 |       final_weight_decay: 0.01
 53 |       lr: 0.0001
 54 |       start_lr: 0.0001
 55 |       warmup: 0.0
 56 |       weight_decay: 0.01
 57 |     - final_lr: 0.0
 58 |       final_weight_decay: 0.1
 59 |       lr: 0.005
 60 |       start_lr: 0.005
 61 |       warmup: 0.0
 62 |       weight_decay: 0.1
 63 |     - final_lr: 0.0
 64 |       final_weight_decay: 0.1
 65 |       lr: 0.003
 66 |       start_lr: 0.003
 67 |       warmup: 0.0
 68 |       weight_decay: 0.1
 69 |     - final_lr: 0.0
 70 |       final_weight_decay: 0.1
 71 |       lr: 0.001
 72 |       start_lr: 0.001
 73 |       warmup: 0.0
 74 |       weight_decay: 0.1
 75 |     - final_lr: 0.0
 76 |       final_weight_decay: 0.1
 77 |       lr: 0.0003
 78 |       start_lr: 0.0003
 79 |       warmup: 0.0
 80 |       weight_decay: 0.1
 81 |     - final_lr: 0.0
 82 |       final_weight_decay: 0.1
 83 |       lr: 0.0001
 84 |       start_lr: 0.0001
 85 |       warmup: 0.0
 86 |       weight_decay: 0.1
 87 |     - final_lr: 0.0
 88 |       final_weight_decay: 0.4
 89 |       lr: 0.005
 90 |       start_lr: 0.005
 91 |       warmup: 0.0
 92 |       weight_decay: 0.4
 93 |     - final_lr: 0.0
 94 |       final_weight_decay: 0.4
 95 |       lr: 0.003
 96 |       start_lr: 0.003
 97 |       warmup: 0.0
 98 |       weight_decay: 0.4
 99 |     - final_lr: 0.0
100 |       final_weight_decay: 0.4
101 |       lr: 0.001
102 |       start_lr: 0.001
103 |       warmup: 0.0
104 |       weight_decay: 0.4
105 |     - final_lr: 0.0
106 |       final_weight_decay: 0.4
107 |       lr: 0.0003
108 |       start_lr: 0.0003
109 |       warmup: 0.0
110 |       weight_decay: 0.4
111 |     - final_lr: 0.0
112 |       final_weight_decay: 0.4
113 |       lr: 0.0001
114 |       start_lr: 0.0001
115 |       warmup: 0.0
116 |       weight_decay: 0.4
117 |     - final_lr: 0.0
118 |       final_weight_decay: 0.8
119 |       lr: 0.005
120 |       start_lr: 0.005
121 |       warmup: 0.0
122 |       weight_decay: 0.8
123 |     - final_lr: 0.0
124 |       final_weight_decay: 0.8
125 |       lr: 0.003
126 |       start_lr: 0.003
127 |       warmup: 0.0
128 |       weight_decay: 0.8
129 |     - final_lr: 0.0
130 |       final_weight_decay: 0.8
131 |       lr: 0.001
132 |       start_lr: 0.001
133 |       warmup: 0.0
134 |       weight_decay: 0.8
135 |     - final_lr: 0.0
136 |       final_weight_decay: 0.8
137 |       lr: 0.0003
138 |       start_lr: 0.0003
139 |       warmup: 0.0
140 |       weight_decay: 0.8
141 |     - final_lr: 0.0
142 |       final_weight_decay: 0.8
143 |       lr: 0.0001
144 |       start_lr: 0.0001
145 |       warmup: 0.0
146 |       weight_decay: 0.8
147 |     num_epochs: 20
148 |     use_bfloat16: true
149 |     use_pos_embed: false
150 | model_kwargs:
151 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
152 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
153 |   pretrain_kwargs:
154 |     encoder:
155 |       checkpoint_key: target_encoder
156 |       img_temporal_dim_size: null
157 |       model_name: vit_large
158 |       patch_size: 16
159 |       tubelet_size: 2
160 |       uniform_power: true
161 |       use_rope: true
162 |   wrapper_kwargs:
163 |     max_frames: 128
164 |     use_pos_embed: false
165 | 


--------------------------------------------------------------------------------
/configs/eval/vitl/ssv2.yaml:
--------------------------------------------------------------------------------
  1 | cpus_per_task: 16
  2 | eval_name: video_classification_frozen
  3 | folder: /your_folder/evals/vitl/ssv2
  4 | mem_per_gpu: 220G
  5 | nodes: 8
  6 | max_workers: 8
  7 | resume_checkpoint: true
  8 | tag: ssv2-vitl16-16x2x3-16f
  9 | tasks_per_node: 8
 10 | experiment:
 11 |   classifier:
 12 |     num_heads: 16
 13 |     num_probe_blocks: 4
 14 |   data:
 15 |     dataset_type: VideoDataset
 16 |     dataset_train: /your_data_path/ssv2_train_paths.csv
 17 |     dataset_val: /your_data_path/ssv2_val_paths.csv
 18 |     frame_step: 4
 19 |     frames_per_clip: 16
 20 |     num_classes: 174
 21 |     num_segments: 2
 22 |     num_views_per_segment: 3
 23 |     resolution: 256
 24 |   optimization:
 25 |     batch_size: 4
 26 |     multihead_kwargs:
 27 |     - final_lr: 0.0
 28 |       final_weight_decay: 0.01
 29 |       lr: 0.005
 30 |       start_lr: 0.005
 31 |       warmup: 0.0
 32 |       weight_decay: 0.01
 33 |     - final_lr: 0.0
 34 |       final_weight_decay: 0.01
 35 |       lr: 0.003
 36 |       start_lr: 0.003
 37 |       warmup: 0.0
 38 |       weight_decay: 0.01
 39 |     - final_lr: 0.0
 40 |       final_weight_decay: 0.01
 41 |       lr: 0.001
 42 |       start_lr: 0.001
 43 |       warmup: 0.0
 44 |       weight_decay: 0.01
 45 |     - final_lr: 0.0
 46 |       final_weight_decay: 0.01
 47 |       lr: 0.0003
 48 |       start_lr: 0.0003
 49 |       warmup: 0.0
 50 |       weight_decay: 0.01
 51 |     - final_lr: 0.0
 52 |       final_weight_decay: 0.01
 53 |       lr: 0.0001
 54 |       start_lr: 0.0001
 55 |       warmup: 0.0
 56 |       weight_decay: 0.01
 57 |     - final_lr: 0.0
 58 |       final_weight_decay: 0.1
 59 |       lr: 0.005
 60 |       start_lr: 0.005
 61 |       warmup: 0.0
 62 |       weight_decay: 0.1
 63 |     - final_lr: 0.0
 64 |       final_weight_decay: 0.1
 65 |       lr: 0.003
 66 |       start_lr: 0.003
 67 |       warmup: 0.0
 68 |       weight_decay: 0.1
 69 |     - final_lr: 0.0
 70 |       final_weight_decay: 0.1
 71 |       lr: 0.001
 72 |       start_lr: 0.001
 73 |       warmup: 0.0
 74 |       weight_decay: 0.1
 75 |     - final_lr: 0.0
 76 |       final_weight_decay: 0.1
 77 |       lr: 0.0003
 78 |       start_lr: 0.0003
 79 |       warmup: 0.0
 80 |       weight_decay: 0.1
 81 |     - final_lr: 0.0
 82 |       final_weight_decay: 0.1
 83 |       lr: 0.0001
 84 |       start_lr: 0.0001
 85 |       warmup: 0.0
 86 |       weight_decay: 0.1
 87 |     - final_lr: 0.0
 88 |       final_weight_decay: 0.4
 89 |       lr: 0.005
 90 |       start_lr: 0.005
 91 |       warmup: 0.0
 92 |       weight_decay: 0.4
 93 |     - final_lr: 0.0
 94 |       final_weight_decay: 0.4
 95 |       lr: 0.003
 96 |       start_lr: 0.003
 97 |       warmup: 0.0
 98 |       weight_decay: 0.4
 99 |     - final_lr: 0.0
100 |       final_weight_decay: 0.4
101 |       lr: 0.001
102 |       start_lr: 0.001
103 |       warmup: 0.0
104 |       weight_decay: 0.4
105 |     - final_lr: 0.0
106 |       final_weight_decay: 0.4
107 |       lr: 0.0003
108 |       start_lr: 0.0003
109 |       warmup: 0.0
110 |       weight_decay: 0.4
111 |     - final_lr: 0.0
112 |       final_weight_decay: 0.4
113 |       lr: 0.0001
114 |       start_lr: 0.0001
115 |       warmup: 0.0
116 |       weight_decay: 0.4
117 |     - final_lr: 0.0
118 |       final_weight_decay: 0.8
119 |       lr: 0.005
120 |       start_lr: 0.005
121 |       warmup: 0.0
122 |       weight_decay: 0.8
123 |     - final_lr: 0.0
124 |       final_weight_decay: 0.8
125 |       lr: 0.003
126 |       start_lr: 0.003
127 |       warmup: 0.0
128 |       weight_decay: 0.8
129 |     - final_lr: 0.0
130 |       final_weight_decay: 0.8
131 |       lr: 0.001
132 |       start_lr: 0.001
133 |       warmup: 0.0
134 |       weight_decay: 0.8
135 |     - final_lr: 0.0
136 |       final_weight_decay: 0.8
137 |       lr: 0.0003
138 |       start_lr: 0.0003
139 |       warmup: 0.0
140 |       weight_decay: 0.8
141 |     - final_lr: 0.0
142 |       final_weight_decay: 0.8
143 |       lr: 0.0001
144 |       start_lr: 0.0001
145 |       warmup: 0.0
146 |       weight_decay: 0.8
147 |     num_epochs: 20
148 |     use_bfloat16: true
149 |     use_pos_embed: false
150 | model_kwargs:
151 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
152 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
153 |   pretrain_kwargs:
154 |     encoder:
155 |       checkpoint_key: target_encoder
156 |       img_temporal_dim_size: null
157 |       model_name: vit_large
158 |       patch_size: 16
159 |       tubelet_size: 2
160 |       uniform_power: true
161 |       use_rope: true
162 |   wrapper_kwargs:
163 |     max_frames: 128
164 |     use_pos_embed: false
165 | 


--------------------------------------------------------------------------------
/configs/inference/vitg-384/diving48.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | mem_per_gpu: 200G
 5 | tag: diving48-vitg16-384-32x4x3
 6 | eval_name: video_classification_frozen
 7 | folder: /your_folder/evals/vitg-384/diving48
 8 | resume_checkpoint: true
 9 | val_only: true
10 | experiment:
11 |   classifier:
12 |     num_probe_blocks: 4
13 |     num_heads: 16
14 |   data:
15 |     dataset_type: VideoDataset
16 |     dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
17 |     dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
18 |     num_classes: 48
19 |     resolution: 384
20 |     frames_per_clip: 32
21 |     frame_step: 2
22 |     num_segments: 4
23 |     num_views_per_segment: 3
24 |   optimization:
25 |     use_pos_embed: false
26 |     num_epochs: 100
27 |     batch_size: 2
28 |     use_bfloat16: true
29 |     multihead_kwargs:
30 |       - weight_decay: 0.0
31 |         final_weight_decay: 0.0
32 |         lr: 0.0
33 |         start_lr: 0.0
34 |         final_lr: 0.0
35 |         warmup: 0.0
36 | model_kwargs:
37 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
38 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
39 |   wrapper_kwargs:
40 |     max_frames: 128
41 |     use_pos_embed: false
42 |     out_layers: [24, 29, 34, 39]
43 |   pretrain_kwargs:
44 |     encoder:
45 |       model_name: vit_giant_xformers
46 |       checkpoint_key: target_encoder
47 |       tubelet_size: 2
48 |       patch_size: 16
49 |       uniform_power: true
50 |       use_rope: true
51 | 


--------------------------------------------------------------------------------
/configs/inference/vitg-384/ek100.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | tag: ek100-vitg16-384
 5 | eval_name: action_anticipation_frozen
 6 | folder: /your_folder/evals/vitg-384/ek100
 7 | resume_checkpoint: true
 8 | val_only: true
 9 | experiment:
10 |   classifier:
11 |     num_probe_blocks: 4
12 |     num_heads: 16
13 |   data:
14 |     anticipation_time_sec:
15 |     - 1.0
16 |     - 1.0
17 |     auto_augment: true
18 |     file_format: 0
19 |     dataset: EK100
20 |     base_path: /your_ek100_root_dir/
21 |     dataset_train: /your_data/EPIC_100_train.csv
22 |     dataset_val: /your_data/EPIC_100_validation.csv
23 |     frames_per_clip: 32
24 |     frames_per_second: 8
25 |     motion_shift: false
26 |     num_workers: 2
27 |     pin_memory: true
28 |     random_resize_scale:
29 |     - 0.08
30 |     - 1.0
31 |     reprob: 0.25
32 |     resolution: 384
33 |     train_anticipation_point:
34 |     - 0.0
35 |     - 0.25
36 |     train_anticipation_time_sec:
37 |     - 0.25
38 |     - 1.75
39 |   optimization:
40 |     num_epochs: 20
41 |     batch_size: 2
42 |     use_bfloat16: true
43 |     use_focal_loss: true
44 |     multihead_kwargs:
45 |       - weight_decay: 0.0
46 |         final_weight_decay: 0.0
47 |         lr: 0.0
48 |         start_lr: 0.0
49 |         final_lr: 0.0
50 |         warmup: 0.0
51 | model_kwargs:
52 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
53 |   module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
54 |   wrapper_kwargs:
55 |     no_predictor: false
56 |     num_output_frames: 2
57 |     num_steps: 1
58 |   pretrain_kwargs:
59 |     encoder:
60 |       model_name: vit_giant_xformers
61 |       checkpoint_key: target_encoder
62 |       tubelet_size: 2
63 |       patch_size: 16
64 |       uniform_power: true
65 |       use_rope: true
66 |     predictor:
67 |       model_name: vit_predictor
68 |       checkpoint_key: predictor
69 |       num_frames: 64
70 |       depth: 12
71 |       num_heads: 12
72 |       predictor_embed_dim: 384
73 |       num_mask_tokens: 10
74 |       uniform_power: true
75 |       use_mask_tokens: true
76 |       use_sdpa: true
77 |       use_silu: false
78 |       wide_silu: false
79 |       use_rope: true
80 | 


--------------------------------------------------------------------------------
/configs/inference/vitg-384/ssv2.yaml:
--------------------------------------------------------------------------------
 1 | cpus_per_task: 16
 2 | eval_name: video_classification_frozen
 3 | folder: /your_folder/evals/vitg-384/ssv2
 4 | mem_per_gpu: 220G
 5 | nodes: 16
 6 | num_workers: 8
 7 | resume_checkpoint: true
 8 | val_only: true
 9 | tag: ssv2-vitg16-384-64x2x3
10 | tasks_per_node: 8
11 | experiment:
12 |   classifier:
13 |     num_heads: 16
14 |     num_probe_blocks: 4
15 |   data:
16 |     dataset_type: VideoDataset
17 |     dataset_train: /your_data_path/ssv2_train_paths.csv
18 |     dataset_val: /your_data_path/ssv2_val_paths.csv
19 |     frame_step: 2
20 |     frames_per_clip: 64
21 |     num_classes: 174
22 |     num_segments: 2
23 |     num_views_per_segment: 3
24 |     resolution: 384
25 |   optimization:
26 |     batch_size: 2
27 |     multihead_kwargs:
28 |     - final_lr: 0.0
29 |       final_weight_decay: 0.0
30 |       lr: 0.0
31 |       start_lr: 0.0
32 |       warmup: 0.0
33 |       weight_decay: 0.0
34 |     num_epochs: 20
35 |     use_bfloat16: true
36 |     use_pos_embed: false
37 | model_kwargs:
38 |   checkpoint: /your_vjepa2_checkpoints/vitg-384.pt
39 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
40 |   pretrain_kwargs:
41 |     encoder:
42 |       checkpoint_key: target_encoder
43 |       img_temporal_dim_size: null
44 |       model_name: vit_giant_xformers
45 |       patch_size: 16
46 |       tubelet_size: 2
47 |       uniform_power: true
48 |       use_rope: true
49 |   wrapper_kwargs:
50 |     max_frames: 128
51 |     use_pos_embed: false
52 | 
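
The configs under /configs/inference mirror their /configs/eval counterparts but add val_only: true and collapse the probe sweep to a single placeholder entry with zeroed learning rate and weight decay, so the attentive-probe weights are presumably restored from the checkpoint under the run folder (resume_checkpoint: true) rather than retrained. A quick way to see the difference, run from the repo root (a sketch, not a repo script):

    import yaml

    with open("configs/eval/vitg-384/ssv2.yaml") as f:
        eval_cfg = yaml.safe_load(f)
    with open("configs/inference/vitg-384/ssv2.yaml") as f:
        infer_cfg = yaml.safe_load(f)

    print(infer_cfg.get("val_only"))                                          # True
    print(len(eval_cfg["experiment"]["optimization"]["multihead_kwargs"]))    # 20
    print(len(infer_cfg["experiment"]["optimization"]["multihead_kwargs"]))   # 1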


--------------------------------------------------------------------------------
/configs/inference/vitl/diving48.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | mem_per_gpu: 200G
 5 | tag: diving48-vitl16-32x4x3
 6 | eval_name: video_classification_frozen
 7 | folder: /your_folder/evals/vitl/diving48
 8 | resume_checkpoint: true
 9 | val_only: true
10 | experiment:
11 |   classifier:
12 |     num_probe_blocks: 4
13 |     num_heads: 16
14 |   data:
15 |     dataset_type: VideoDataset
16 |     dataset_train: /your_data_dir/diving48/annotations/Diving48_train_paths.csv
17 |     dataset_val: /your_data_dir/diving48/annotations/Diving48_test_paths.csv
18 |     num_classes: 48
19 |     resolution: 256
20 |     frames_per_clip: 32
21 |     frame_step: 2
22 |     num_segments: 4
23 |     num_views_per_segment: 3
24 |   optimization:
25 |     use_pos_embed: false
26 |     num_epochs: 100
27 |     batch_size: 2
28 |     use_bfloat16: true
29 |     multihead_kwargs:
30 |       - weight_decay: 0.0
31 |         final_weight_decay: 0.0
32 |         lr: 0.0
33 |         start_lr: 0.0
34 |         final_lr: 0.0
35 |         warmup: 0.0
36 | model_kwargs:
37 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
38 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip_multilevel
39 |   wrapper_kwargs:
40 |     max_frames: 128
41 |     use_pos_embed: false
42 |     out_layers: [17, 19, 21, 23]
43 |   pretrain_kwargs:
44 |     encoder:
45 |       model_name: vit_large
46 |       checkpoint_key: target_encoder
47 |       tubelet_size: 2
48 |       patch_size: 16
49 |       uniform_power: true
50 |       use_rope: true
51 | 


--------------------------------------------------------------------------------
/configs/inference/vitl/ek100.yaml:
--------------------------------------------------------------------------------
 1 | nodes: 8
 2 | tasks_per_node: 8
 3 | cpus_per_task: 12
 4 | tag: ek100-vitl16
 5 | eval_name: action_anticipation_frozen
 6 | folder: /your_folder/evals/vitl/ek100
 7 | resume_checkpoint: true
 8 | val_only: true
 9 | experiment:
10 |   classifier:
11 |     num_probe_blocks: 4
12 |     num_heads: 16
13 |   data:
14 |     anticipation_time_sec:
15 |     - 1.0
16 |     - 1.0
17 |     auto_augment: true
18 |     file_format: 0
19 |     dataset: EK100
20 |     base_path: /your_ek100_root_dir/
21 |     dataset_train: /your_data/EPIC_100_train.csv
22 |     dataset_val: /your_data/EPIC_100_validation.csv
23 |     frames_per_clip: 32
24 |     frames_per_second: 8
25 |     motion_shift: false
26 |     num_workers: 2
27 |     pin_memory: true
28 |     random_resize_scale:
29 |     - 0.08
30 |     - 1.0
31 |     reprob: 0.25
32 |     resolution: 256
33 |     train_anticipation_point:
34 |     - 0.0
35 |     - 0.25
36 |     train_anticipation_time_sec:
37 |     - 0.25
38 |     - 1.75
39 |   optimization:
40 |     num_epochs: 20
41 |     batch_size: 2
42 |     use_bfloat16: true
43 |     use_focal_loss: true
44 |     multihead_kwargs:
45 |       - weight_decay: 0.0
46 |         final_weight_decay: 0.0
47 |         lr: 0.0
48 |         start_lr: 0.0
49 |         final_lr: 0.0
50 |         warmup: 0.0
51 | model_kwargs:
52 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
53 |   module_name: evals.action_anticipation_frozen.modelcustom.vit_encoder_predictor_concat_ar
54 |   wrapper_kwargs:
55 |     no_predictor: false
56 |     num_output_frames: 2
57 |     num_steps: 1
58 |   pretrain_kwargs:
59 |     encoder:
60 |       model_name: vit_large
61 |       checkpoint_key: target_encoder
62 |       tubelet_size: 2
63 |       patch_size: 16
64 |       uniform_power: true
65 |       use_rope: true
66 |     predictor:
67 |       model_name: vit_predictor
68 |       checkpoint_key: predictor
69 |       num_frames: 64
70 |       depth: 12
71 |       num_heads: 12
72 |       predictor_embed_dim: 384
73 |       num_mask_tokens: 10
74 |       uniform_power: true
75 |       use_mask_tokens: true
76 |       use_sdpa: true
77 |       use_silu: false
78 |       wide_silu: false
79 |       use_rope: true
80 | 


--------------------------------------------------------------------------------
/configs/inference/vitl/ssv2.yaml:
--------------------------------------------------------------------------------
 1 | cpus_per_task: 16
 2 | eval_name: video_classification_frozen
 3 | folder: /your_folder/evals/vitl/ssv2
 4 | mem_per_gpu: 220G
 5 | nodes: 8
 6 | max_workers: 8
 7 | resume_checkpoint: true
 8 | val_only: true
 9 | tag: ssv2-vitl16-16x2x3-16f
10 | tasks_per_node: 8
11 | experiment:
12 |   classifier:
13 |     num_heads: 16
14 |     num_probe_blocks: 4
15 |   data:
16 |     dataset_type: VideoDataset
17 |     dataset_train: /your_data_path/ssv2_train_paths.csv
18 |     dataset_val: /your_data_path/ssv2_val_paths.csv
19 |     frame_step: 4
20 |     frames_per_clip: 16
21 |     num_classes: 174
22 |     num_segments: 2
23 |     num_views_per_segment: 3
24 |     resolution: 256
25 |   optimization:
26 |     batch_size: 4
27 |     multihead_kwargs:
28 |     - final_lr: 0.0
29 |       final_weight_decay: 0.0
30 |       lr: 0.0
31 |       start_lr: 0.0
32 |       warmup: 0.0
33 |       weight_decay: 0.0
34 |     num_epochs: 20
35 |     use_bfloat16: true
36 |     use_pos_embed: false
37 | model_kwargs:
38 |   checkpoint: /your_vjepa2_checkpoints/vitl.pt
39 |   module_name: evals.video_classification_frozen.modelcustom.vit_encoder_multiclip
40 |   pretrain_kwargs:
41 |     encoder:
42 |       checkpoint_key: target_encoder
43 |       img_temporal_dim_size: null
44 |       model_name: vit_large
45 |       patch_size: 16
46 |       tubelet_size: 2
47 |       uniform_power: true
48 |       use_rope: true
49 |   wrapper_kwargs:
50 |     max_frames: 128
51 |     use_pos_embed: false
52 | 


--------------------------------------------------------------------------------
/configs/train/vitg16/cooldown-256px-64f.yaml:
--------------------------------------------------------------------------------
  1 | app: vjepa
  2 | cpus_per_task: 32
  3 | folder: /your_folder/anneal/64.8.vitg16-256px-64f
  4 | mem_per_gpu: 220G
  5 | nodes: 64
  6 | tasks_per_node: 8
  7 | data:
  8 |   dataset_type: VideoDataset
  9 |   datasets:
 10 |   - /your_k710_root_dir/k710_train_paths.csv
 11 |   - /your_data_path/ssv2_train_paths.csv
 12 |   - /your_data/howto_320p.csv
 13 |   datasets_weights:
 14 |   - 0.335
 15 |   - 0.100
 16 |   - 0.565
 17 |   batch_size: 6
 18 |   crop_size: 256
 19 |   dataset_fpcs:
 20 |   - 64
 21 |   - 64
 22 |   - 64
 23 |   fps: 4
 24 |   num_workers: 12
 25 |   patch_size: 16
 26 |   persistent_workers: true
 27 |   pin_mem: false
 28 |   tubelet_size: 2
 29 | data_aug:
 30 |   auto_augment: false
 31 |   motion_shift: false
 32 |   random_resize_aspect_ratio:
 33 |   - 0.75
 34 |   - 1.35
 35 |   random_resize_scale:
 36 |   - 0.3
 37 |   - 1.0
 38 |   reprob: 0.0
 39 | loss:
 40 |   loss_exp: 1.0
 41 | mask:
 42 | - aspect_ratio:
 43 |   - 0.75
 44 |   - 1.5
 45 |   full_complement: false
 46 |   max_keep: null
 47 |   max_temporal_keep: 1.0
 48 |   num_blocks: 8
 49 |   spatial_scale:
 50 |   - 0.15
 51 |   - 0.15
 52 |   temporal_scale:
 53 |   - 1.0
 54 |   - 1.0
 55 | - aspect_ratio:
 56 |   - 0.75
 57 |   - 1.5
 58 |   full_complement: false
 59 |   max_keep: null
 60 |   max_temporal_keep: 1.0
 61 |   num_blocks: 2
 62 |   spatial_scale:
 63 |   - 0.7
 64 |   - 0.7
 65 |   temporal_scale:
 66 |   - 1.0
 67 |   - 1.0
 68 | meta:
 69 |   dtype: bfloat16
 70 |   eval_freq: 100
 71 |   load_checkpoint: true
 72 |   read_checkpoint: null
 73 |   save_every_freq: 50
 74 |   seed: 239
 75 |   use_sdpa: true
 76 | model:
 77 |   model_name: vit_giant_xformers
 78 |   pred_depth: 12
 79 |   pred_embed_dim: 384
 80 |   pred_num_heads: 12
 81 |   uniform_power: true
 82 |   use_activation_checkpointing: true
 83 |   use_mask_tokens: true
 84 |   use_rope: true
 85 |   zero_init_mask_tokens: true
 86 | optimization:
 87 |   anneal_ckpt: /your_folder/pretrain/16.8.vitg.256px.16f/e0.pt
 88 |   ema:
 89 |   - 0.99925
 90 |   - 0.99925
 91 |   epochs: 4
 92 |   final_lr: 1.0e-06
 93 |   final_weight_decay: 0.04
 94 |   ipe: 30
 95 |   ipe_scale: 1.25
 96 |   is_anneal: true
 97 |   lr: 0.000525
 98 |   resume_anneal: true
 99 |   start_lr: 0.0001
100 |   warmup: 0
101 |   weight_decay: 0.04
102 | 
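
This cooldown config resumes from the pretrain run (anneal_ckpt points at the e0.pt checkpoint of the 16.8.vitg.256px.16f pretrain below) and anneals briefly with no warmup: 4 epochs of 30 iterations, with the cosine horizon stretched by ipe_scale. A rough sketch of the implied learning-rate curve, assuming a plain cosine decay over ipe * ipe_scale * epochs steps (an assumption; the actual scheduler is defined in the training code, not in this config):

    import math

    ipe, ipe_scale, epochs = 30, 1.25, 4            # values from the file above
    lr, final_lr = 0.000525, 1.0e-06

    train_steps = epochs * ipe                      # 120 optimizer updates
    horizon = int(ipe * ipe_scale * epochs)         # 150-step cosine horizon

    def cosine_lr(step):
        t = min(step, horizon) / horizon
        return final_lr + 0.5 * (lr - final_lr) * (1.0 + math.cos(math.pi * t))

    print(cosine_lr(0), cosine_lr(train_steps))     # starts at lr, ends above final_lr

Under that assumption, the horizon exceeds the actual number of updates, so the learning rate never quite reaches final_lr during the cooldown.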


--------------------------------------------------------------------------------
/configs/train/vitg16/cooldown-384px-64f.yaml:
--------------------------------------------------------------------------------
  1 | app: vjepa
  2 | cpus_per_task: 32
  3 | folder: /your_folder/anneal/64.8.vitg16-384px-64f
  4 | mem_per_gpu: 220G
  5 | nodes: 64
  6 | tasks_per_node: 8
  7 | data:
  8 |   dataset_type: VideoDataset
  9 |   datasets:
 10 |   - /your_k710_root_dir/k710_train_paths.csv
 11 |   - /your_data_path/ssv2_train_paths.csv
 12 |   - /your_data/howto_320p.csv
 13 |   datasets_weights:
 14 |   - 0.335
 15 |   - 0.100
 16 |   - 0.565
 17 |   batch_size: 6
 18 |   crop_size: 384
 19 |   dataset_fpcs:
 20 |   - 64
 21 |   - 64
 22 |   - 64
 23 |   fps: 4
 24 |   num_workers: 12
 25 |   patch_size: 16
 26 |   persistent_workers: true
 27 |   pin_mem: false
 28 |   tubelet_size: 2
 29 | data_aug:
 30 |   auto_augment: false
 31 |   motion_shift: false
 32 |   random_resize_aspect_ratio:
 33 |   - 0.75
 34 |   - 1.35
 35 |   random_resize_scale:
 36 |   - 0.3
 37 |   - 1.0
 38 |   reprob: 0.0
 39 | loss:
 40 |   loss_exp: 1.0
 41 | mask:
 42 | - aspect_ratio:
 43 |   - 0.75
 44 |   - 1.5
 45 |   full_complement: false
 46 |   max_keep: null
 47 |   max_temporal_keep: 1.0
 48 |   num_blocks: 8
 49 |   spatial_scale:
 50 |   - 0.15
 51 |   - 0.15
 52 |   temporal_scale:
 53 |   - 1.0
 54 |   - 1.0
 55 | - aspect_ratio:
 56 |   - 0.75
 57 |   - 1.5
 58 |   full_complement: false
 59 |   max_keep: null
 60 |   max_temporal_keep: 1.0
 61 |   num_blocks: 2
 62 |   spatial_scale:
 63 |   - 0.7
 64 |   - 0.7
 65 |   temporal_scale:
 66 |   - 1.0
 67 |   - 1.0
 68 | meta:
 69 |   dtype: bfloat16
 70 |   eval_freq: 100
 71 |   load_checkpoint: true
 72 |   read_checkpoint: null
 73 |   save_every_freq: 50
 74 |   seed: 239
 75 |   use_sdpa: true
 76 | model:
 77 |   model_name: vit_giant_xformers
 78 |   pred_depth: 12
 79 |   pred_embed_dim: 384
 80 |   pred_num_heads: 12
 81 |   uniform_power: true
 82 |   use_activation_checkpointing: true
 83 |   use_mask_tokens: true
 84 |   use_rope: true
 85 |   zero_init_mask_tokens: true
 86 | optimization:
 87 |   anneal_ckpt: /your_folder/pretrain/16.8.vitg.256px.16f/e0.pt
 88 |   ema:
 89 |   - 0.99925
 90 |   - 0.99925
 91 |   epochs: 40
 92 |   final_lr: 1.0e-06
 93 |   final_weight_decay: 0.04
 94 |   ipe: 300
 95 |   ipe_scale: 1.25
 96 |   is_anneal: true
 97 |   lr: 0.000525
 98 |   resume_anneal: true
 99 |   start_lr: 0.0001
100 |   warmup: 0
101 |   weight_decay: 0.04
102 | 


--------------------------------------------------------------------------------
/configs/train/vitg16/droid-256px-8f.yaml:
--------------------------------------------------------------------------------
 1 | app: vjepa_droid
 2 | cpus_per_task: 12
 3 | folder: /your_folder/droid/4.8.vitg16-256px-8f
 4 | mem_per_gpu: 220G
 5 | nodes: 4
 6 | tasks_per_node: 8
 7 | data:
 8 |   batch_size: 8
 9 |   camera_views:
10 |   - left_mp4_path
11 |   crop_size: 256
12 |   datasets:
13 |     - /your_file_path/droid_train_paths_cw.csv
14 |   dataset_fpcs:
15 |   - 8
16 |   fps: 4
17 |   num_workers: 12
18 |   patch_size: 16
19 |   pin_mem: true
20 |   stereo_view: false
21 |   tubelet_size: 2
22 | data_aug:
23 |   auto_augment: false
24 |   horizontal_flip: false
25 |   motion_shift: false
26 |   random_resize_aspect_ratio:
27 |   - 0.75
28 |   - 1.35
29 |   random_resize_scale:
30 |   - 1.777
31 |   - 1.777
32 |   reprob: 0.0
33 | loss:
34 |   auto_steps: 2
35 |   loss_exp: 1.0
36 |   normalize_reps: true
37 |   reg_coeff: 0.0
38 | meta:
39 |   dtype: bfloat16
40 |   eval_freq: 100
41 |   resume_checkpoint: null
42 |   load_predictor: false
43 |   pretrain_checkpoint: /your_vjepa2_checkpoints/vitg.pt
44 |   context_encoder_key: target_encoder
45 |   target_encoder_key: target_encoder
46 |   save_every_freq: 25
47 |   seed: 239
48 |   use_sdpa: true
49 | model:
50 |   model_name: vit_giant_xformers
51 |   pred_depth: 24
52 |   pred_embed_dim: 1024
53 |   pred_is_frame_causal: true
54 |   pred_num_heads: 16
55 |   uniform_power: true
56 |   use_activation_checkpointing: true
57 |   use_extrinsics: false
58 |   use_rope: true
59 | optimization:
60 |   anneal: 15
61 |   epochs: 315
62 |   final_lr: 0.0
63 |   final_weight_decay: 0.04
64 |   ipe: 300
65 |   lr: 0.000425
66 |   start_lr: 0.000075
67 |   warmup: 15
68 |   weight_decay: 0.04
69 | 
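
This action-conditioned stage initializes both context and target encoders from the target_encoder weights of the vitg pretrain checkpoint and trains a new, larger, frame-causal predictor from scratch (load_predictor: false, pred_depth 24, pred_embed_dim 1024). For reference, the clip geometry implied by the data block, assuming the usual tubelet/patch tokenization:

    frames, crop, tubelet, patch = 8, 256, 2, 16    # values from the file above
    tokens = (frames // tubelet) * (crop // patch) ** 2
    print(tokens)                                   # 1024 tokens per 8-frame clip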


--------------------------------------------------------------------------------
/configs/train/vitg16/pretrain-256px-16f.yaml:
--------------------------------------------------------------------------------
 1 | app: vjepa
 2 | nodes: 16
 3 | tasks_per_node: 8
 4 | cpus_per_task: 16
 5 | mem_per_gpu: 220G
 6 | folder: /your_folder/pretrain/16.8.vitg.256px.16f
 7 | data:
 8 |   dataset_type: VideoDataset
 9 |   datasets:
10 |   - /your_k710_root_dir/k710_train_paths.csv
11 |   - /your_data_path/ssv2_train_paths.csv
12 |   - /your_data/howto_320p.csv
13 |   datasets_weights:
14 |   - 0.335
15 |   - 0.100
16 |   - 0.565
17 |   batch_size: 24
18 |   crop_size: 256
19 |   patch_size: 16
20 |   dataset_fpcs:
21 |   - 16
22 |   - 16
23 |   - 16
24 |   tubelet_size: 2
25 |   fps: 4
26 |   num_workers: 8
27 |   persistent_workers: true
28 |   pin_mem: true
29 | data_aug:
30 |   auto_augment: false
31 |   motion_shift: false
32 |   random_resize_aspect_ratio:
33 |   - 0.75
34 |   - 1.35
35 |   random_resize_scale:
36 |   - 0.3
37 |   - 1.0
38 |   reprob: 0.0
39 | loss:
40 |   loss_exp: 1.0
41 | mask:
42 | - aspect_ratio:
43 |   - 0.75
44 |   - 1.5
45 |   full_complement: false
46 |   max_keep: null
47 |   max_temporal_keep: 1.0
48 |   num_blocks: 8
49 |   spatial_scale:
50 |   - 0.15
51 |   - 0.15
52 |   temporal_scale:
53 |   - 1.0
54 |   - 1.0
55 | - aspect_ratio:
56 |   - 0.75
57 |   - 1.5
58 |   full_complement: false
59 |   max_keep: null
60 |   max_temporal_keep: 1.0
61 |   num_blocks: 2
62 |   spatial_scale:
63 |   - 0.7
64 |   - 0.7
65 |   temporal_scale:
66 |   - 1.0
67 |   - 1.0
68 | meta:
69 |   dtype: bfloat16
70 |   eval_freq: 100
71 |   load_checkpoint: true
72 |   read_checkpoint: null
73 |   save_every_freq: 50
74 |   seed: 239
75 |   use_sdpa: true
76 | model:
77 |   model_name: vit_giant_xformers
78 |   pred_depth: 12
79 |   pred_embed_dim: 384
80 |   pred_num_heads: 12
81 |   uniform_power: true
82 |   use_activation_checkpointing: true
83 |   use_mask_tokens: true
84 |   use_rope: true
85 |   zero_init_mask_tokens: true
86 | optimization:
87 |   ema:
88 |   - 0.99925
89 |   - 0.99925
90 |   epochs: 800
91 |   final_lr: 0.000525
92 |   final_weight_decay: 0.04
93 |   ipe: 300
94 |   ipe_scale: 1.25
95 |   lr: 0.000525
96 |   start_lr: 0.0001
97 |   warmup: 40
98 |   weight_decay: 0.04
99 | 
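
A quick sketch of consuming one of these training configs outside the launcher, assuming PyYAML is installed and the placeholder paths have been replaced; the field names follow the vitg16 pretrain config above:

import yaml

# Load the config with PyYAML and inspect a few fields.
with open("configs/train/vitg16/pretrain-256px-16f.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["app"])                   # "vjepa"
print(cfg["data"]["dataset_fpcs"])  # [16, 16, 16] frames per clip, one entry per dataset
print(cfg["optimization"]["lr"])    # 0.000525
print(len(cfg["mask"]))             # 2 block-mask specs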


--------------------------------------------------------------------------------
/configs/train/vith16/cooldown-256px-64f.yaml:
--------------------------------------------------------------------------------
  1 | app: vjepa
  2 | cpus_per_task: 32
  3 | folder: /your_folder/anneal/32.8.vith16-256px-64f
  4 | mem_per_gpu: 220G
  5 | nodes: 32
  6 | tasks_per_node: 8
  7 | data:
  8 |   dataset_type: VideoDataset
  9 |   datasets:
 10 |   - /your_k710_root_dir/k710_train_paths.csv
 11 |   - /your_data_path/ssv2_train_paths.csv
 12 |   - /your_data/howto_320p.csv
 13 |   datasets_weights:
 14 |   - 0.335
 15 |   - 0.100
 16 |   - 0.565
 17 |   batch_size: 12
 18 |   crop_size: 256
 19 |   dataset_fpcs:
 20 |   - 64
 21 |   - 64
 22 |   - 64
 23 |   fps: 4
 24 |   num_workers: 12
 25 |   patch_size: 16
 26 |   persistent_workers: true
 27 |   pin_mem: false
 28 |   tubelet_size: 2
 29 | data_aug:
 30 |   auto_augment: false
 31 |   motion_shift: false
 32 |   random_resize_aspect_ratio:
 33 |   - 0.75
 34 |   - 1.35
 35 |   random_resize_scale:
 36 |   - 0.3
 37 |   - 1.0
 38 |   reprob: 0.0
 39 | loss:
 40 |   loss_exp: 1.0
 41 | mask:
 42 | - aspect_ratio:
 43 |   - 0.75
 44 |   - 1.5
 45 |   full_complement: false
 46 |   max_keep: null
 47 |   max_temporal_keep: 1.0
 48 |   num_blocks: 8
 49 |   spatial_scale:
 50 |   - 0.15
 51 |   - 0.15
 52 |   temporal_scale:
 53 |   - 1.0
 54 |   - 1.0
 55 | - aspect_ratio:
 56 |   - 0.75
 57 |   - 1.5
 58 |   full_complement: false
 59 |   max_keep: null
 60 |   max_temporal_keep: 1.0
 61 |   num_blocks: 2
 62 |   spatial_scale:
 63 |   - 0.7
 64 |   - 0.7
 65 |   temporal_scale:
 66 |   - 1.0
 67 |   - 1.0
 68 | meta:
 69 |   dtype: bfloat16
 70 |   eval_freq: 100
 71 |   load_checkpoint: true
 72 |   read_checkpoint: null
 73 |   save_every_freq: 50
 74 |   seed: 239
 75 |   use_sdpa: true
 76 | model:
 77 |   model_name: vit_huge
 78 |   pred_depth: 12
 79 |   pred_embed_dim: 384
 80 |   pred_num_heads: 12
 81 |   uniform_power: true
 82 |   use_activation_checkpointing: true
 83 |   use_mask_tokens: true
 84 |   use_rope: true
 85 |   zero_init_mask_tokens: true
 86 | optimization:
 87 |   anneal_ckpt: /your_folder/pretrain/16.8.vith.256px.16f/e0.pt
 88 |   ema:
 89 |   - 0.99925
 90 |   - 0.99925
 91 |   epochs: 40
 92 |   final_lr: 1.0e-06
 93 |   final_weight_decay: 0.04
 94 |   ipe: 300
 95 |   ipe_scale: 1.25
 96 |   is_anneal: true
 97 |   lr: 0.000525
 98 |   resume_anneal: true
 99 |   start_lr: 0.0001
100 |   warmup: 0
101 |   weight_decay: 0.04
102 | 


--------------------------------------------------------------------------------
/configs/train/vith16/pretrain-256px-16f.yaml:
--------------------------------------------------------------------------------
 1 | app: vjepa
 2 | nodes: 16
 3 | tasks_per_node: 8
 4 | cpus_per_task: 16
 5 | mem_per_gpu: 220G
 6 | folder: /your_folder/pretrain/16.8.vith.256px.16f
 7 | data:
 8 |   dataset_type: VideoDataset
 9 |   datasets:
10 |   - /your_k710_root_dir/k710_train_paths.csv
11 |   - /your_data_path/ssv2_train_paths.csv
12 |   - /your_data/howto_320p.csv
13 |   datasets_weights:
14 |   - 0.335
15 |   - 0.100
16 |   - 0.565
17 |   batch_size: 24
18 |   crop_size: 256
19 |   patch_size: 16
20 |   dataset_fpcs:
21 |   - 16
22 |   - 16
23 |   - 16
24 |   tubelet_size: 2
25 |   fps: 4
26 |   num_workers: 8
27 |   persistent_workers: true
28 |   pin_mem: true
29 | data_aug:
30 |   auto_augment: false
31 |   motion_shift: false
32 |   random_resize_aspect_ratio:
33 |   - 0.75
34 |   - 1.35
35 |   random_resize_scale:
36 |   - 0.3
37 |   - 1.0
38 |   reprob: 0.0
39 | loss:
40 |   loss_exp: 1.0
41 | mask:
42 | - aspect_ratio:
43 |   - 0.75
44 |   - 1.5
45 |   full_complement: false
46 |   max_keep: null
47 |   max_temporal_keep: 1.0
48 |   num_blocks: 8
49 |   spatial_scale:
50 |   - 0.15
51 |   - 0.15
52 |   temporal_scale:
53 |   - 1.0
54 |   - 1.0
55 | - aspect_ratio:
56 |   - 0.75
57 |   - 1.5
58 |   full_complement: false
59 |   max_keep: null
60 |   max_temporal_keep: 1.0
61 |   num_blocks: 2
62 |   spatial_scale:
63 |   - 0.7
64 |   - 0.7
65 |   temporal_scale:
66 |   - 1.0
67 |   - 1.0
68 | meta:
69 |   dtype: bfloat16
70 |   eval_freq: 100
71 |   load_checkpoint: true
72 |   read_checkpoint: null
73 |   save_every_freq: 50
74 |   seed: 239
75 |   use_sdpa: true
76 | model:
77 |   model_name: vit_huge
78 |   pred_depth: 12
79 |   pred_embed_dim: 384
80 |   pred_num_heads: 12
81 |   uniform_power: true
82 |   use_activation_checkpointing: true
83 |   use_mask_tokens: true
84 |   use_rope: true
85 |   zero_init_mask_tokens: true
86 | optimization:
87 |   ema:
88 |   - 0.99925
89 |   - 0.99925
90 |   epochs: 10
91 |   final_lr: 0.000525
92 |   final_weight_decay: 0.04
93 |   ipe: 300
94 |   ipe_scale: 1.25
95 |   lr: 0.000525
96 |   start_lr: 0.0001
97 |   warmup: 40
98 |   weight_decay: 0.04
99 | 


--------------------------------------------------------------------------------
/configs/train/vitl16/cooldown-256px-64f.yaml:
--------------------------------------------------------------------------------
  1 | app: vjepa
  2 | cpus_per_task: 32
  3 | folder: /your_folder/anneal/32.8.vitl16-256px-64f
  4 | mem_per_gpu: 220G
  5 | nodes: 32
  6 | tasks_per_node: 8
  7 | data:
  8 |   dataset_type: VideoDataset
  9 |   datasets:
 10 |   - /your_k710_root_dir/k710_train_paths.csv
 11 |   - /your_data_path/ssv2_train_paths.csv
 12 |   - /your_data/howto_320p.csv
 13 |   datasets_weights:
 14 |   - 0.335
 15 |   - 0.100
 16 |   - 0.565
 17 |   batch_size: 12
 18 |   crop_size: 256
 19 |   dataset_fpcs:
 20 |   - 64
 21 |   - 64
 22 |   - 64
 23 |   fps: 4
 24 |   num_workers: 12
 25 |   patch_size: 16
 26 |   persistent_workers: true
 27 |   pin_mem: false
 28 |   tubelet_size: 2
 29 | data_aug:
 30 |   auto_augment: false
 31 |   motion_shift: false
 32 |   random_resize_aspect_ratio:
 33 |   - 0.75
 34 |   - 1.35
 35 |   random_resize_scale:
 36 |   - 0.3
 37 |   - 1.0
 38 |   reprob: 0.0
 39 | loss:
 40 |   loss_exp: 1.0
 41 | mask:
 42 | - aspect_ratio:
 43 |   - 0.75
 44 |   - 1.5
 45 |   full_complement: false
 46 |   max_keep: null
 47 |   max_temporal_keep: 1.0
 48 |   num_blocks: 8
 49 |   spatial_scale:
 50 |   - 0.15
 51 |   - 0.15
 52 |   temporal_scale:
 53 |   - 1.0
 54 |   - 1.0
 55 | - aspect_ratio:
 56 |   - 0.75
 57 |   - 1.5
 58 |   full_complement: false
 59 |   max_keep: null
 60 |   max_temporal_keep: 1.0
 61 |   num_blocks: 2
 62 |   spatial_scale:
 63 |   - 0.7
 64 |   - 0.7
 65 |   temporal_scale:
 66 |   - 1.0
 67 |   - 1.0
 68 | meta:
 69 |   dtype: bfloat16
 70 |   eval_freq: 100
 71 |   load_checkpoint: true
 72 |   read_checkpoint: null
 73 |   save_every_freq: 50
 74 |   seed: 239
 75 |   use_sdpa: true
 76 | model:
 77 |   model_name: vit_large
 78 |   pred_depth: 12
 79 |   pred_embed_dim: 384
 80 |   pred_num_heads: 12
 81 |   uniform_power: true
 82 |   use_activation_checkpointing: true
 83 |   use_mask_tokens: true
 84 |   use_rope: true
 85 |   zero_init_mask_tokens: true
 86 | optimization:
 87 |   anneal_ckpt: /your_folder/pretrain/16.8.vitl.256px.16f/e0.pt
 88 |   ema:
 89 |   - 0.99925
 90 |   - 0.99925
 91 |   epochs: 40
 92 |   final_lr: 1.0e-06
 93 |   final_weight_decay: 0.04
 94 |   ipe: 300
 95 |   ipe_scale: 1.25
 96 |   is_anneal: true
 97 |   lr: 0.000525
 98 |   resume_anneal: true
 99 |   start_lr: 0.0001
100 |   warmup: 0
101 |   weight_decay: 0.04
102 | 


--------------------------------------------------------------------------------
/configs/train/vitl16/pretrain-256px-16f.yaml:
--------------------------------------------------------------------------------
 1 | app: vjepa
 2 | nodes: 16
 3 | tasks_per_node: 8
 4 | cpus_per_task: 16
 5 | mem_per_gpu: 220G
 6 | folder: /your_folder/pretrain/16.8.vitl.256px.16f
 7 | data:
 8 |   dataset_type: VideoDataset
 9 |   datasets:
10 |   - /your_k710_root_dir/k710_train_paths.csv
11 |   - /your_data_path/ssv2_train_paths.csv
12 |   - /your_data/howto_320p.csv
13 |   datasets_weights:
14 |   - 0.335
15 |   - 0.100
16 |   - 0.565
17 |   batch_size: 24
18 |   crop_size: 256
19 |   patch_size: 16
20 |   dataset_fpcs:
21 |   - 16
22 |   - 16
23 |   - 16
24 |   tubelet_size: 2
25 |   fps: 4
26 |   num_workers: 8
27 |   persistent_workers: true
28 |   pin_mem: true
29 | data_aug:
30 |   auto_augment: false
31 |   motion_shift: false
32 |   random_resize_aspect_ratio:
33 |   - 0.75
34 |   - 1.35
35 |   random_resize_scale:
36 |   - 0.3
37 |   - 1.0
38 |   reprob: 0.0
39 | loss:
40 |   loss_exp: 1.0
41 | mask:
42 | - aspect_ratio:
43 |   - 0.75
44 |   - 1.5
45 |   full_complement: false
46 |   max_keep: null
47 |   max_temporal_keep: 1.0
48 |   num_blocks: 8
49 |   spatial_scale:
50 |   - 0.15
51 |   - 0.15
52 |   temporal_scale:
53 |   - 1.0
54 |   - 1.0
55 | - aspect_ratio:
56 |   - 0.75
57 |   - 1.5
58 |   full_complement: false
59 |   max_keep: null
60 |   max_temporal_keep: 1.0
61 |   num_blocks: 2
62 |   spatial_scale:
63 |   - 0.7
64 |   - 0.7
65 |   temporal_scale:
66 |   - 1.0
67 |   - 1.0
68 | meta:
69 |   dtype: bfloat16
70 |   eval_freq: 100
71 |   load_checkpoint: true
72 |   read_checkpoint: null
73 |   save_every_freq: 50
74 |   seed: 239
75 |   use_sdpa: true
76 | model:
77 |   model_name: vit_large
78 |   pred_depth: 12
79 |   pred_embed_dim: 384
80 |   pred_num_heads: 12
81 |   uniform_power: true
82 |   use_activation_checkpointing: true
83 |   use_mask_tokens: true
84 |   use_rope: true
85 |   zero_init_mask_tokens: true
86 | optimization:
87 |   ema:
88 |   - 0.99925
89 |   - 0.99925
90 |   epochs: 10
91 |   final_lr: 0.000525
92 |   final_weight_decay: 0.04
93 |   ipe: 300
94 |   ipe_scale: 1.25
95 |   lr: 0.000525
96 |   start_lr: 0.0001
97 |   warmup: 40
98 |   weight_decay: 0.04
99 | 


--------------------------------------------------------------------------------
/evals/action_anticipation_frozen/losses.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch.nn.functional as F
 7 | 
 8 | 
 9 | def sigmoid_focal_loss(
10 |     inputs,
11 |     targets,
12 |     alpha=0.25,
13 |     gamma=2.0,
14 |     reduction="sum",
15 |     detach=False,
16 | ):
17 |     """
18 |     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
19 | 
20 |     :param Tensor inputs: Prediction logits for each sample [B x K]
21 |     :param Tensor targets: Class label for each sample [B] (long tensor)
22 |     :param float alpha: Weight in range (0,1) to balance pos vs neg samples.
23 |     :param float gamma: Exponent of modulating factor (1-p_t) to balance easy vs hard samples.
 24 |     :param str reduction: 'mean' | 'sum'; any other value leaves the loss unreduced (detach=True detaches the result)
25 |     """
26 |     B, K = inputs.size()  # [batch_size, class logits]
27 | 
28 |     # convert to one-hot targets
29 |     targets = F.one_hot(targets, K).float()  # [B, K]
30 | 
31 |     p = F.sigmoid(inputs)
32 | 
33 |     ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
34 |     p_t = p * targets + (1 - p) * (1 - targets)
35 |     loss = ce_loss * ((1 - p_t) ** gamma)
36 | 
37 |     if alpha >= 0:
38 |         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
39 |         loss = alpha_t * loss
40 | 
41 |     if reduction == "mean":
42 |         loss = loss.mean()
43 |     elif reduction == "sum":
44 |         loss = loss.sum()
45 | 
46 |     if detach:
47 |         loss = loss.detach()
48 | 
49 |     return loss
50 | 
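
A minimal usage sketch for sigmoid_focal_loss above; the batch size and class count are arbitrary:

import torch

from evals.action_anticipation_frozen.losses import sigmoid_focal_loss

logits = torch.randn(4, 10, requires_grad=True)  # [B, K] prediction logits
targets = torch.randint(0, 10, (4,))             # [B] integer class labels
loss = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean")
loss.backward()
print(loss.item())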


--------------------------------------------------------------------------------
/evals/action_anticipation_frozen/metrics.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch
 7 | import torch.distributed as dist
 8 | import torch.nn.functional as F
 9 | 
10 | 
11 | class ClassMeanRecall:
12 | 
13 |     def __init__(self, num_classes: int, device: torch.device, k=5):
14 |         self.num_classes = num_classes
15 |         self.TP = torch.zeros(num_classes).to(device)
16 |         self.FN = torch.zeros(num_classes).to(device)
17 |         self.k = k
18 | 
19 |     def __call__(self, logits, labels, valid_classes=None, eps=1e-8):
20 |         """
21 |         :param logits: Tensors of shape [B, num_classes]
22 |         :param labels: Tensors of shape [B]
23 |         :param valid_classes: set
24 |         """
25 |         k, tp_tensor, fn_tensor = self.k, self.TP, self.FN
26 |         logits = F.sigmoid(logits)
27 | 
28 |         if valid_classes is not None:
29 |             _logits = torch.zeros(logits.shape).to(logits.device)
30 |             for c in valid_classes:
31 |                 _logits[:, c] = logits[:, c]
32 |             logits = _logits
33 | 
34 |         preds = logits.topk(k, dim=1).indices
35 | 
36 |         # Loop over batch and check whether all targets are within top-k logit
37 |         # predictions for their respective class, if so TP else FN
38 |         for p, gt in zip(preds, labels):
39 |             if gt in p:
40 |                 tp_tensor[gt] += 1
41 |             else:
42 |                 fn_tensor[gt] += 1
43 | 
 44 |         # Aggregate TP/FN across all workers, but clone first so that the
 45 |         # in-place all_reduce does not modify tp_tensor and fn_tensor, which
 46 |         # only track local quantities.
47 |         TP, FN = tp_tensor.clone(), fn_tensor.clone()
48 |         dist.all_reduce(TP)
49 |         dist.all_reduce(FN)
50 | 
 51 |         nch = torch.sum((TP + FN) > 0)  # num classes hit; may not have TP/FN data for all classes yet
52 |         recall = 100.0 * torch.sum(TP / (TP + FN + eps)) / nch  # mean class recall
53 |         topk = 100.0 * sum(TP) / int(sum(TP + FN))  # accuracy
54 | 
55 |         return dict(
56 |             recall=recall,
57 |             accuracy=topk,
58 |         )
59 | 
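
ClassMeanRecall calls dist.all_reduce, so torch.distributed must be initialized even on a single process; a minimal single-process sketch using the gloo backend (class count and batch size are arbitrary):

import os

import torch
import torch.distributed as dist

from evals.action_anticipation_frozen.metrics import ClassMeanRecall

# One-process process group so the all_reduce calls inside the meter work (and are effectively no-ops).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

meter = ClassMeanRecall(num_classes=10, device=torch.device("cpu"), k=5)
logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
print(meter(logits, labels))  # {'recall': tensor(...), 'accuracy': tensor(...)}

dist.destroy_process_group()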


--------------------------------------------------------------------------------
/evals/action_anticipation_frozen/models.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import importlib
  7 | import logging
  8 | 
  9 | import torch
 10 | import torch.nn as nn
 11 | 
 12 | from src.models.attentive_pooler import AttentivePooler
 13 | 
 14 | logging.basicConfig()
 15 | logger = logging.getLogger()
 16 | logger.setLevel(logging.INFO)
 17 | 
 18 | 
 19 | class AttentiveClassifier(nn.Module):
 20 | 
 21 |     def __init__(
 22 |         self,
 23 |         verb_classes: dict,
 24 |         noun_classes: dict,
 25 |         action_classes: dict,
 26 |         embed_dim: int,
 27 |         num_heads: int,
 28 |         depth: int,
 29 |         use_activation_checkpointing: bool,
 30 |     ):
 31 |         super().__init__()
 32 |         self.num_verb_classes = len(verb_classes)
 33 |         num_noun_classes = len(noun_classes)
 34 |         num_action_classes = len(action_classes)
 35 |         self.action_only = self.num_verb_classes == 0
 36 | 
 37 |         self.pooler = AttentivePooler(
 38 |             num_queries=1 if self.action_only else 3,
 39 |             embed_dim=embed_dim,
 40 |             num_heads=num_heads,
 41 |             depth=depth,
 42 |             use_activation_checkpointing=use_activation_checkpointing,
 43 |         )
 44 |         if not self.action_only:
 45 |             self.verb_classifier = nn.Linear(embed_dim, self.num_verb_classes, bias=True)
 46 |             self.noun_classifier = nn.Linear(embed_dim, num_noun_classes, bias=True)
 47 |         self.action_classifier = nn.Linear(embed_dim, num_action_classes, bias=True)
 48 | 
 49 |     def forward(self, x):
 50 |         if torch.isnan(x).any():
 51 |             print("Nan detected at output of encoder")
 52 |             exit(1)
 53 | 
 54 |         x = self.pooler(x)  # [B, 1, D] if action_only else [B, 3, D]
 55 |         if not self.action_only:
 56 |             x_verb, x_noun, x_action = x[:, 0, :], x[:, 1, :], x[:, 2, :]
 57 |             x_verb = self.verb_classifier(x_verb)
 58 |             x_noun = self.noun_classifier(x_noun)
 59 |             x_action = self.action_classifier(x_action)
 60 |             return dict(
 61 |                 verb=x_verb,
 62 |                 noun=x_noun,
 63 |                 action=x_action,
 64 |             )
 65 |         else:
 66 |             x_action = x[:, 0, :]
 67 |             x_action = self.action_classifier(x_action)
 68 |             return dict(action=x_action)
 69 | 
 70 | 
 71 | def init_module(
 72 |     module_name,
 73 |     device,
 74 |     frames_per_clip,
 75 |     frames_per_second,
 76 |     resolution,
 77 |     checkpoint,
 78 |     model_kwargs,
 79 |     wrapper_kwargs,
 80 | ):
 81 |     """
 82 |     Build (frozen) model and initialize from pretrained checkpoint
 83 | 
 84 |     API requirements for "model" module:
 85 |       1) Needs to be a pytorch module with 'forward()' function protocol:
 86 |         :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width])
 87 |         :param anticipation_time: (Tensor) Seconds into the future to predict for each sample in batch
 88 |             (shape=[batch_size])
 89 |         :returns: (Tensor) Representations of future frames (shape=[batch_size x num_output_tokens x feature_dim])
 90 | 
 91 |       2) Needs to have a public attribute called 'embed_dim' (int) describing its
 92 |          output feature dimension.
 93 |     """
 94 |     model = (
 95 |         importlib.import_module(f"{module_name}")
 96 |         .init_module(
 97 |             frames_per_clip=frames_per_clip,
 98 |             frames_per_second=frames_per_second,
 99 |             resolution=resolution,
100 |             checkpoint=checkpoint,
101 |             model_kwargs=model_kwargs,
102 |             wrapper_kwargs=wrapper_kwargs,
103 |         )
104 |         .to(device)
105 |     )
106 |     model.eval()
107 |     for p in model.parameters():
108 |         p.requires_grad = False
109 |     print(model)
110 |     return model
111 | 
112 | 
113 | def init_classifier(
114 |     embed_dim: int,
115 |     num_heads: int,
116 |     num_blocks: int,
117 |     device: torch.device,
118 |     num_classifiers: int,
119 |     action_classes: dict,
120 |     verb_classes: dict,
121 |     noun_classes: dict,
122 | ):
123 |     classifiers = [
124 |         AttentiveClassifier(
125 |             verb_classes=verb_classes,
126 |             noun_classes=noun_classes,
127 |             action_classes=action_classes,
128 |             embed_dim=embed_dim,
129 |             num_heads=num_heads,
130 |             depth=num_blocks,
131 |             use_activation_checkpointing=True,
132 |         ).to(device)
133 |         for _ in range(num_classifiers)
134 |     ]
135 |     print(classifiers[0])
136 |     return classifiers
137 | 
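
A construction sketch for AttentiveClassifier; the label dictionaries are dummies (only their lengths matter), and the [B, N, D] token layout is an assumption about what AttentivePooler accepts:

import torch

from evals.action_anticipation_frozen.models import AttentiveClassifier

verb_classes = {f"verb_{i}": i for i in range(5)}
noun_classes = {f"noun_{i}": i for i in range(7)}
action_classes = {f"action_{i}": i for i in range(9)}

clf = AttentiveClassifier(
    verb_classes=verb_classes,
    noun_classes=noun_classes,
    action_classes=action_classes,
    embed_dim=64,
    num_heads=4,
    depth=1,
    use_activation_checkpointing=False,
)
tokens = torch.randn(2, 32, 64)  # [B, N, D] frozen-encoder features (assumed layout)
out = clf(tokens)
print({k: v.shape for k, v in out.items()})  # verb [2, 5], noun [2, 7], action [2, 9]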


--------------------------------------------------------------------------------
/evals/action_anticipation_frozen/utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import logging
 7 | import math
 8 | 
 9 | import torch
10 | 
11 | logging.basicConfig()
12 | logger = logging.getLogger()
13 | logger.setLevel(logging.INFO)
14 | 
15 | 
16 | def init_opt(classifiers, iterations_per_epoch, opt_kwargs, num_epochs, use_bfloat16=False):
17 |     optimizers, schedulers, wd_schedulers, scalers = [], [], [], []
18 |     for c, kwargs in zip(classifiers, opt_kwargs):
19 |         param_groups = [
20 |             {
21 |                 "params": (p for n, p in c.named_parameters()),
22 |                 "mc_warmup_steps": int(kwargs.get("warmup") * iterations_per_epoch),
23 |                 "mc_start_lr": kwargs.get("start_lr"),
24 |                 "mc_ref_lr": kwargs.get("ref_lr"),
25 |                 "mc_final_lr": kwargs.get("final_lr"),
26 |                 "mc_ref_wd": kwargs.get("ref_wd"),
27 |                 "mc_final_wd": kwargs.get("final_wd"),
28 |             }
29 |         ]
30 |         logger.info("Using AdamW")
31 |         optimizers += [torch.optim.AdamW(param_groups)]
32 |         schedulers += [WarmupCosineLRSchedule(optimizers[-1], T_max=int(num_epochs * iterations_per_epoch))]
33 |         wd_schedulers += [CosineWDSchedule(optimizers[-1], T_max=int(num_epochs * iterations_per_epoch))]
34 |         scalers += [torch.cuda.amp.GradScaler() if use_bfloat16 else None]
35 |     return optimizers, scalers, schedulers, wd_schedulers
36 | 
37 | 
38 | class WarmupCosineLRSchedule(object):
39 | 
40 |     def __init__(self, optimizer, T_max, last_epoch=-1):
41 |         self.optimizer = optimizer
42 |         self.T_max = T_max
43 |         self._step = 0.0
44 | 
45 |     def step(self):
46 |         self._step += 1
47 |         for group in self.optimizer.param_groups:
48 |             ref_lr = group.get("mc_ref_lr")
49 |             final_lr = group.get("mc_final_lr")
50 |             start_lr = group.get("mc_start_lr")
51 |             warmup_steps = group.get("mc_warmup_steps")
52 |             T_max = self.T_max - warmup_steps
53 |             if self._step < warmup_steps:
54 |                 progress = float(self._step) / float(max(1, warmup_steps))
55 |                 new_lr = start_lr + progress * (ref_lr - start_lr)
56 |             else:
57 |                 # -- progress after warmup
58 |                 progress = float(self._step - warmup_steps) / float(max(1, T_max))
59 |                 new_lr = max(
60 |                     final_lr,
61 |                     final_lr + (ref_lr - final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)),
62 |                 )
63 |             group["lr"] = new_lr
64 | 
65 | 
66 | class CosineWDSchedule(object):
67 | 
68 |     def __init__(self, optimizer, T_max):
69 |         self.optimizer = optimizer
70 |         self.T_max = T_max
71 |         self._step = 0.0
72 | 
73 |     def step(self):
74 |         self._step += 1
75 |         progress = self._step / self.T_max
76 | 
77 |         for group in self.optimizer.param_groups:
78 |             ref_wd = group.get("mc_ref_wd")
79 |             final_wd = group.get("mc_final_wd")
80 |             new_wd = final_wd + (ref_wd - final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))
81 |             if final_wd <= ref_wd:
82 |                 new_wd = max(final_wd, new_wd)
83 |             else:
84 |                 new_wd = min(final_wd, new_wd)
85 |             group["weight_decay"] = new_wd
86 | 
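
A toy sketch of driving the schedules above directly; the mc_* keys mirror what init_opt() attaches to each param group, and the linear probe and hyperparameter values are arbitrary:

import torch

from evals.action_anticipation_frozen.utils import CosineWDSchedule, WarmupCosineLRSchedule

probe = torch.nn.Linear(16, 4)  # stand-in for an attentive classifier
optimizer = torch.optim.AdamW(
    [
        {
            "params": probe.parameters(),
            "mc_warmup_steps": 10,
            "mc_start_lr": 1e-6,
            "mc_ref_lr": 1e-3,
            "mc_final_lr": 0.0,
            "mc_ref_wd": 0.04,
            "mc_final_wd": 0.4,
        }
    ]
)
lr_schedule = WarmupCosineLRSchedule(optimizer, T_max=100)
wd_schedule = CosineWDSchedule(optimizer, T_max=100)

for _ in range(100):  # e.g. one epoch of 100 iterations
    lr_schedule.step()
    wd_schedule.step()
print(optimizer.param_groups[0]["lr"], optimizer.param_groups[0]["weight_decay"])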


--------------------------------------------------------------------------------
/evals/hub/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/evals/hub/__init__.py


--------------------------------------------------------------------------------
/evals/hub/preprocessor.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | 
 7 | def _make_transforms(crop_size=256):
 8 |     from ..video_classification_frozen.utils import make_transforms
 9 | 
10 |     return make_transforms(crop_size=crop_size, training=False)
11 | 
12 | 
13 | def vjepa2_preprocessor(*, pretrained: bool = True, **kwargs):
14 |     crop_size = kwargs.get("crop_size", 256)
15 |     return _make_transforms(crop_size=crop_size)
16 | 


--------------------------------------------------------------------------------
/evals/image_classification_frozen/modelcustom/vit_encoder.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Copyright (c) Meta Platforms, Inc. and affiliates.
 3 | 
 4 | This source code is licensed under the MIT license found in the
 5 | LICENSE file in the root directory of this source tree.
 6 | ------------------------------------------------------------------------------
 7 | 
 8 | modelcustom API requirements:
 9 | 
10 | API requirements for Encoder module:
11 |     1) Needs to be a pytorch module with 'forward()' function protocol:
12 |         :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width])
13 |         :returns: (Tensor) Representations of video clip (shape=[batch_size x num_encoder_tokens x feature_dim])
14 |     2) Needs to have a public attribute called 'embed_dim' (int) describing its
15 |         output feature dimension.
16 | 
17 | API requirements for Predictor module:
18 |     1) Needs to be a pytorch module with 'forward()' function protocol:
19 |         :param x: (Tensor) Video clip tokens (shape=[batch_size x num_encoder_tokens x feature_dim])
20 |         :param anticipation_time: (Tensor) Seconds into the future to predict for each sample in batch
21 |             (shape=[batch_size])
22 |         :returns: (Tensor) Representations of future frames (shape=[batch_size x num_output_tokens x feature_dim])
23 |     2) Needs to have a public attribute called 'embed_dim' (int) describing its
24 |         output feature dimension.
25 | """
26 | 
27 | import logging
28 | 
29 | import torch
30 | 
31 | import src.models.vision_transformer as vit
32 | 
33 | logging.basicConfig()
34 | logger = logging.getLogger()
35 | logger.setLevel(logging.INFO)
36 | 
37 | 
38 | def init_module(
39 |     resolution: int,
40 |     checkpoint: str,
41 |     # --
42 |     model_kwargs: dict,
43 |     wrapper_kwargs: dict,
44 |     **kwargs,
45 | ):
46 |     logger.info(f"Loading pretrained model from {checkpoint=}")
47 |     checkpoint = torch.load(checkpoint, map_location="cpu")
48 | 
49 |     img_as_video_nframes = wrapper_kwargs.get("img_as_video_nframes")
50 |     # --
51 |     enc_kwargs = model_kwargs["encoder"]
52 |     enc_ckp_key = enc_kwargs.get("checkpoint_key")
53 |     enc_model_name = enc_kwargs.get("model_name")
54 | 
55 |     model = vit.__dict__[enc_model_name](
56 |         input_size=resolution,
57 |         num_frames=img_as_video_nframes,
58 |         **enc_kwargs,
59 |     )
60 | 
61 |     def forward_prehook(module, input):
62 |         input = input[0]  # [B, C, H, W]
63 |         input = input.unsqueeze(2).repeat(1, 1, img_as_video_nframes, 1, 1)
64 |         return input
65 | 
66 |     model.register_forward_pre_hook(forward_prehook)
67 | 
68 |     pretrained_dict = checkpoint[enc_ckp_key]
69 |     # --
70 |     pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
71 |     pretrained_dict = {k.replace("backbone.", ""): v for k, v in pretrained_dict.items()}
72 |     for k, v in model.state_dict().items():
73 |         if k not in pretrained_dict:
74 |             logger.info(f'key "{k}" could not be found in loaded state dict')
75 |         elif pretrained_dict[k].shape != v.shape:
76 |             logger.info(f'key "{k}" is of different shape in model and loaded state dict')
77 |             pretrained_dict[k] = v
78 |     msg = model.load_state_dict(pretrained_dict, strict=False)
79 |     logger.info(f"loaded pretrained model with msg: {msg}")
80 |     print(model)
81 | 
82 |     del checkpoint
83 |     return model
84 | 


--------------------------------------------------------------------------------
/evals/image_classification_frozen/models.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import importlib
 7 | import logging
 8 | 
 9 | logging.basicConfig()
10 | logger = logging.getLogger()
11 | logger.setLevel(logging.INFO)
12 | 
13 | 
14 | def init_module(
15 |     module_name,
16 |     device,
17 |     resolution,
18 |     checkpoint,
19 |     model_kwargs,
20 |     wrapper_kwargs,
21 | ):
22 |     """
23 |     Build (frozen) model and initialize from pretrained checkpoint
24 | 
25 |     API requirements for Encoder module:
26 |       1) Needs to be a pytorch module with 'forward()' function protocol:
27 |         :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width])
28 |         :returns: (Tensor) Representations of video clip (shape=[batch_size x num_encoder_tokens x feature_dim])
29 |     """
30 |     model = (
31 |         importlib.import_module(f"{module_name}")
32 |         .init_module(
33 |             resolution=resolution,
34 |             checkpoint=checkpoint,
35 |             model_kwargs=model_kwargs,
36 |             wrapper_kwargs=wrapper_kwargs,
37 |         )
38 |         .to(device)
39 |     )
40 |     model.eval()
41 |     for p in model.parameters():
42 |         p.requires_grad = False
43 |     print(model)
44 |     return model
45 | 


--------------------------------------------------------------------------------
/evals/main.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import argparse
  7 | import multiprocessing as mp
  8 | import os
  9 | import pprint
 10 | 
 11 | import yaml
 12 | 
 13 | from evals.scaffold import main as eval_main
 14 | from src.utils.distributed import init_distributed
 15 | 
 16 | parser = argparse.ArgumentParser()
 17 | parser.add_argument("--val_only", action="store_true", help="only run eval", default=False)
 18 | parser.add_argument("--fname", type=str, help="name of config file to load", default="configs.yaml")
 19 | parser.add_argument(
 20 |     "--devices",
 21 |     type=str,
 22 |     nargs="+",
 23 |     default=["cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7"],
 24 |     help="which devices to use on local machine",
 25 | )
 26 | parser.add_argument(
 27 |     "--debugmode",
 28 |     type=bool,
 29 |     default=False,
 30 |     help="Setting this to true will not spin up new processes. "
 31 |     "Everything runs in the main process, which makes it easier to debug with checkpointing.",
 32 | )
 33 | parser.add_argument(
 34 |     "--folder",
 35 |     type=str,
 36 |     help="location to save logs",
 37 |     default="",
 38 | )
 39 | parser.add_argument("--override_config_folder", action="store_true")
 40 | parser.add_argument("--checkpoint", type=str, help="location of pretrained ckpt")
 41 | parser.add_argument("--model_name", type=str, help="Model name")
 42 | parser.add_argument("--batch_size", type=int)
 43 | parser.add_argument("--use_fsdp", action="store_true")
 44 | 
 45 | 
 46 | def process_main(args, rank, fname, world_size, devices):
 47 |     import logging
 48 |     import os
 49 | 
 50 |     os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[rank].split(":")[-1])
 51 | 
 52 |     logging.basicConfig()
 53 |     logger = logging.getLogger()
 54 |     if rank == 0:
 55 |         logger.setLevel(logging.INFO)
 56 |     else:
 57 |         logger.setLevel(logging.ERROR)
 58 | 
 59 |     logger.info(f"called-params {fname}")
 60 | 
 61 |     # Load config
 62 |     params = None
 63 |     with open(fname, "r") as y_file:
 64 |         params = yaml.load(y_file, Loader=yaml.FullLoader)
 65 |         if args.val_only:
 66 |             params["val_only"] = True
 67 | 
 68 |         if args.checkpoint:
 69 |             params["model_kwargs"]["checkpoint"] = args.checkpoint
 70 | 
 71 |         if args.model_name:
 72 |             params["model_kwargs"]["pretrain_kwargs"]["encoder"]["model_name"] = args.model_name
 73 | 
 74 |         if args.batch_size:
 75 |             params["experiment"]["optimization"]["batch_size"] = args.batch_size
 76 | 
 77 |         if args.override_config_folder:
 78 |             params["folder"] = args.folder
 79 |         params["use_fsdp"] = args.use_fsdp
 80 |         logger.info("loaded params...")
 81 | 
 82 |     if rank == 0:
 83 |         pprint.PrettyPrinter(indent=4).pprint(params)
 84 | 
 85 |     # Init distributed (access to comm between GPUS on same machine)
 86 |     world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
 87 |     logger.info(f"Running... (rank: {rank}/{world_size})")
 88 | 
 89 |     # Launch the eval with loaded config
 90 |     eval_main(params["eval_name"], args_eval=params)
 91 | 
 92 | 
 93 | if __name__ == "__main__":
 94 |     args = parser.parse_args()
 95 |     if args.debugmode:
 96 |         # FSDP debugging (use torchrun)
 97 |         if args.use_fsdp:
 98 |             process_main(
 99 |                 args=args,
100 |                 rank=int(os.environ["RANK"]),
101 |                 fname=args.fname,
102 |                 world_size=int(os.environ["WORLD_SIZE"]),
103 |                 devices=args.devices,
104 |             )
105 |         # Single-GPU debugging
106 |         else:
107 |             process_main(args=args, rank=0, fname=args.fname, world_size=1, devices=["cuda:0"])
108 |     else:
109 |         num_gpus = len(args.devices)
110 |         mp.set_start_method("spawn")
111 |         for rank in range(num_gpus):
112 |             mp.Process(target=process_main, args=(args, rank, args.fname, num_gpus, args.devices)).start()
113 | 


--------------------------------------------------------------------------------
/evals/scaffold.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import importlib
 7 | 
 8 | from src.utils.logging import get_logger
 9 | 
10 | logger = get_logger("Eval runner scaffold")
11 | 
12 | 
13 | def main(eval_name, args_eval, resume_preempt=False):
14 |     logger.info(f"Running evaluation: {eval_name}")
15 |     if eval_name.startswith("app."):
16 |         import_path = f"{eval_name}.eval"
17 |     else:
18 |         import_path = f"evals.{eval_name}.eval"
19 |     return importlib.import_module(import_path).main(args_eval=args_eval, resume_preempt=resume_preempt)
20 | 


--------------------------------------------------------------------------------
/evals/video_classification_frozen/models.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import importlib
 7 | import logging
 8 | 
 9 | logging.basicConfig()
10 | logger = logging.getLogger()
11 | logger.setLevel(logging.INFO)
12 | 
13 | 
14 | def init_module(
15 |     module_name,
16 |     device,
17 |     frames_per_clip,
18 |     resolution,
19 |     checkpoint,
20 |     model_kwargs,
21 |     wrapper_kwargs,
22 | ):
23 |     """
24 |     Build (frozen) model and initialize from pretrained checkpoint
25 | 
26 |     API requirements for Encoder module:
27 |       1) Needs to be a pytorch module with 'forward()' function protocol:
28 |         :param x: (Tensor) Video clip (shape=[batch_size x num_channels x num_frames x height x width])
29 |         :returns: (Tensor) Representations of video clip (shape=[batch_size x num_encoder_tokens x feature_dim])
30 |     """
31 |     model = (
32 |         importlib.import_module(f"{module_name}")
33 |         .init_module(
34 |             frames_per_clip=frames_per_clip,
35 |             resolution=resolution,
36 |             checkpoint=checkpoint,
37 |             model_kwargs=model_kwargs,
38 |             wrapper_kwargs=wrapper_kwargs,
39 |         )
40 |         .to(device)
41 |     )
42 |     model.eval()
43 |     for p in model.parameters():
44 |         p.requires_grad = False
45 |     print(model)
46 |     return model
47 | 


--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from evals.hub.preprocessor import vjepa2_preprocessor
 7 | from src.hub.backbones import (
 8 |     vjepa2_ac_vit_giant,
 9 |     vjepa2_vit_giant,
10 |     vjepa2_vit_giant_384,
11 |     vjepa2_vit_huge,
12 |     vjepa2_vit_large,
13 | )
14 | 
15 | dependencies = ["torch", "timm", "einops"]
16 | 
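
The entrypoints re-exported here can be pulled with torch.hub; a sketch that needs network access on first use, and where the exact objects returned are defined in src/hub/backbones.py (not reproduced in this listing):

import torch

# Entrypoint names come from the imports in hubconf.py above.
processor = torch.hub.load("facebookresearch/vjepa2", "vjepa2_preprocessor")
vit_large = torch.hub.load("facebookresearch/vjepa2", "vjepa2_vit_large")
print(type(processor), type(vit_large))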


--------------------------------------------------------------------------------
/notebooks/franka_example_traj.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/notebooks/franka_example_traj.npz


--------------------------------------------------------------------------------
/notebooks/utils/world_model_wrapper.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import numpy as np
 7 | import torch.nn.functional as F
 8 | 
 9 | from .mpc_utils import cem, compute_new_pose
10 | 
11 | 
12 | class WorldModel(object):
13 | 
14 |     def __init__(
15 |         self,
16 |         encoder,
17 |         predictor,
18 |         tokens_per_frame,
19 |         transform,
20 |         mpc_args={
21 |             "rollout": 2,
22 |             "samples": 400,
23 |             "topk": 10,
24 |             "cem_steps": 10,
25 |             "momentum_mean": 0.15,
26 |             "momentum_std": 0.15,
27 |             "maxnorm": 0.05,
28 |             "verbose": True,
29 |         },
30 |         normalize_reps=True,
31 |         device="cuda:0",
32 |     ):
33 |         super().__init__()
34 |         self.encoder = encoder
35 |         self.predictor = predictor
36 |         self.normalize_reps = normalize_reps
37 |         self.transform = transform
38 |         self.tokens_per_frame = tokens_per_frame
39 |         self.device = device
40 |         self.mpc_args = mpc_args
41 | 
42 |     def encode(self, image):
43 |         clip = np.expand_dims(image, axis=0)
44 |         clip = self.transform(clip)[None, :]
45 |         B, C, T, H, W = clip.size()
46 |         clip = clip.permute(0, 2, 1, 3, 4).flatten(0, 1).unsqueeze(2).repeat(1, 1, 2, 1, 1)
47 |         clip = clip.to(self.device, non_blocking=True)
48 |         h = self.encoder(clip)
49 |         h = h.view(B, T, -1, h.size(-1)).flatten(1, 2)
50 |         if self.normalize_reps:
51 |             h = F.layer_norm(h, (h.size(-1),))
52 |         return h
53 | 
54 |     def infer_next_action(self, rep, pose, goal_rep, close_gripper=None):
55 | 
56 |         def step_predictor(reps, actions, poses):
57 |             B, T, N_T, D = reps.size()
58 |             reps = reps.flatten(1, 2)
59 |             next_rep = self.predictor(reps, actions, poses)[:, -self.tokens_per_frame :]
60 |             if self.normalize_reps:
61 |                 next_rep = F.layer_norm(next_rep, (next_rep.size(-1),))
62 |             next_rep = next_rep.view(B, 1, N_T, D)
63 |             next_pose = compute_new_pose(poses[:, -1:], actions[:, -1:])
64 |             return next_rep, next_pose
65 | 
66 |         mpc_action = cem(
67 |             context_frame=rep,
68 |             context_pose=pose,
69 |             goal_frame=goal_rep,
70 |             world_model=step_predictor,
71 |             close_gripper=close_gripper,
72 |             **self.mpc_args,
73 |         )[0]
74 | 
75 |         return mpc_action
76 | 


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | profile="black"
3 | line_length=119
4 | 
5 | [tool.black]
6 | line-length = 119
7 | 


--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | # Tools for static checking.
2 | black == 24.4.2
3 | flake8 == 7.0.0
4 | isort == 5.13.2


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | torch>=2
 2 | torchvision
 3 | tensorboard
 4 | wandb
 5 | iopath
 6 | pyyaml
 7 | numpy
 8 | opencv-python
 9 | submitit
10 | braceexpand
11 | webdataset
12 | timm
13 | transformers
14 | peft
15 | decord
16 | pandas
17 | einops
18 | beartype
19 | psutil
20 | h5py
21 | fire
22 | python-box
23 | scikit-image
24 | ftfy
25 | jupyter
26 | 


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from setuptools import setup
 7 | 
 8 | NAME = "vjepa2"
 9 | VERSION = "0.0.1"
10 | DESCRIPTION = "PyTorch code and models for V-JEPA 2."
11 | URL = "https://github.com/facebookresearch/vjepa2"
12 | 
13 | 
14 | def get_requirements():
15 |     with open("./requirements.txt") as reqsf:
16 |         reqs = reqsf.readlines()
17 |     return reqs
18 | 
19 | 
20 | if __name__ == "__main__":
21 |     setup(
22 |         name=NAME,
23 |         version=VERSION,
24 |         description=DESCRIPTION,
25 |         url=URL,
26 |         python_requires=">=3.11",
27 |         install_requires=get_requirements(),
28 |     )
29 | 


--------------------------------------------------------------------------------
/src/datasets/data_manager.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from logging import getLogger
 7 | 
 8 | _GLOBAL_SEED = 0
 9 | logger = getLogger()
10 | 
11 | 
12 | def init_data(
13 |     batch_size,
14 |     transform=None,
15 |     shared_transform=None,
16 |     data="ImageNet",
17 |     collator=None,
18 |     pin_mem=True,
19 |     num_workers=8,
20 |     world_size=1,
21 |     rank=0,
22 |     root_path=None,
23 |     image_folder=None,
24 |     training=True,
25 |     drop_last=True,
26 |     subset_file=None,
27 |     clip_len=None,
28 |     dataset_fpcs=None,
29 |     frame_sample_rate=None,
30 |     duration=None,
31 |     fps=None,
32 |     num_clips=1,
33 |     random_clip_sampling=True,
34 |     allow_clip_overlap=False,
35 |     filter_short_videos=False,
36 |     filter_long_videos=int(1e9),
37 |     datasets_weights=None,
38 |     persistent_workers=False,
39 |     deterministic=True,
40 |     log_dir=None,
41 | ):
42 |     if data.lower() == "imagenet":
43 |         from src.datasets.imagenet1k import make_imagenet1k
44 | 
45 |         dataset, data_loader, dist_sampler = make_imagenet1k(
46 |             transform=transform,
47 |             batch_size=batch_size,
48 |             collator=collator,
49 |             pin_mem=pin_mem,
50 |             training=training,
51 |             num_workers=num_workers,
52 |             world_size=world_size,
53 |             rank=rank,
54 |             root_path=root_path,
55 |             image_folder=image_folder,
56 |             persistent_workers=persistent_workers,
57 |             drop_last=drop_last,
58 |             subset_file=subset_file,
59 |         )
60 | 
61 |     elif data.lower() == "videodataset":
62 |         from src.datasets.video_dataset import make_videodataset
63 | 
64 |         dataset, data_loader, dist_sampler = make_videodataset(
65 |             data_paths=root_path,
66 |             batch_size=batch_size,
67 |             frames_per_clip=clip_len,
68 |             dataset_fpcs=dataset_fpcs,
69 |             frame_step=frame_sample_rate,
70 |             duration=duration,
71 |             fps=fps,
72 |             num_clips=num_clips,
73 |             random_clip_sampling=random_clip_sampling,
74 |             allow_clip_overlap=allow_clip_overlap,
75 |             filter_short_videos=filter_short_videos,
76 |             filter_long_videos=filter_long_videos,
77 |             shared_transform=shared_transform,
78 |             transform=transform,
79 |             datasets_weights=datasets_weights,
80 |             collator=collator,
81 |             num_workers=num_workers,
82 |             pin_mem=pin_mem,
83 |             persistent_workers=persistent_workers,
84 |             world_size=world_size,
85 |             rank=rank,
86 |             deterministic=deterministic,
87 |             log_dir=log_dir,
88 |         )
89 | 
90 |     return (data_loader, dist_sampler)
91 | 
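
A hedged sketch of the VideoDataset branch of init_data(); the CSV path is a placeholder and transform is left as None purely to show the argument mapping (the training apps supply real video transforms), so treat this as an illustration rather than a ready-to-run call:

from src.datasets.data_manager import init_data

data_loader, dist_sampler = init_data(
    data="VideoDataset",
    root_path=["/your_data_path/ssv2_train_paths.csv"],  # placeholder CSV of video paths
    batch_size=4,
    dataset_fpcs=[16],
    fps=4,
    num_workers=0,
    world_size=1,
    rank=0,
    transform=None,  # real runs pass a video transform here
    collator=None,
)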


--------------------------------------------------------------------------------
/src/datasets/imagenet1k.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import os
  7 | import subprocess
  8 | import time
  9 | from logging import getLogger
 10 | 
 11 | import numpy as np
 12 | import torch
 13 | import torchvision
 14 | 
 15 | _GLOBAL_SEED = 0
 16 | logger = getLogger()
 17 | 
 18 | 
 19 | class ImageNet(torchvision.datasets.ImageFolder):
 20 | 
 21 |     def __init__(
 22 |         self,
 23 |         root,
 24 |         image_folder="imagenet_full_size/061417/",
 25 |         tar_file="imagenet_full_size-061417.tar.gz",
 26 |         transform=None,
 27 |         train=True,
 28 |         job_id=None,
 29 |         local_rank=None,
 30 |         index_targets=False,
 31 |     ):
 32 |         """
 33 |         ImageNet
 34 | 
 35 |         Dataset wrapper
 36 | 
 37 |         :param root: root network directory for ImageNet data
 38 |         :param image_folder: path to images inside root network directory
 39 |         :param tar_file: zipped image_folder inside root network directory
 40 |         :param train: whether to load train data (or validation)
 41 |         :param job_id: scheduler job-id used to create dir on local machine
 42 |         :param index_targets: whether to index the id of each labeled image
 43 |         """
 44 | 
 45 |         suffix = "train/" if train else "val/"
 46 |         data_path = os.path.join(root, image_folder, suffix)
 47 |         logger.info(f"data-path {data_path}")
 48 | 
 49 |         super(ImageNet, self).__init__(root=data_path, transform=transform)
 50 |         logger.info("Initialized ImageNet")
 51 | 
 52 |         if index_targets:
 53 |             self.targets = []
 54 |             for sample in self.samples:
 55 |                 self.targets.append(sample[1])
 56 |             self.targets = np.array(self.targets)
 57 |             self.samples = np.array(self.samples)
 58 | 
 59 |             mint = None
 60 |             self.target_indices = []
 61 |             for t in range(len(self.classes)):
 62 |                 indices = np.squeeze(np.argwhere(self.targets == t)).tolist()
 63 |                 self.target_indices.append(indices)
 64 |                 mint = len(indices) if mint is None else min(mint, len(indices))
 65 |                 logger.debug(f"num-labeled target {t} {len(indices)}")
 66 |             logger.info(f"min. labeled indices {mint}")
 67 | 
 68 | 
 69 | class ImageNetSubset(object):
 70 | 
 71 |     def __init__(self, dataset, subset_file):
 72 |         """
 73 |         ImageNetSubset
 74 | 
 75 |         :param dataset: ImageNet dataset object
 76 |         :param subset_file: '.txt' file containing IDs of IN1K images to keep
 77 |         """
 78 |         self.dataset = dataset
 79 |         self.subset_file = subset_file
 80 |         self.filter_dataset_(subset_file)
 81 | 
 82 |     def filter_dataset_(self, subset_file):
 83 |         """Filter self.dataset to a subset"""
 84 |         root = self.dataset.root
 85 |         class_to_idx = self.dataset.class_to_idx
 86 |         # -- update samples to subset of IN1k targets/samples
 87 |         new_samples = []
 88 |         logger.info(f"Using {subset_file}")
 89 |         with open(subset_file, "r") as rfile:
 90 |             for line in rfile:
 91 |                 class_name = line.split("_")[0]
 92 |                 target = class_to_idx[class_name]
 93 |                 img = line.split("\n")[0]
 94 |                 new_samples.append((os.path.join(root, class_name, img), target))
 95 |         self.samples = new_samples
 96 | 
 97 |     @property
 98 |     def classes(self):
 99 |         return self.dataset.classes
100 | 
101 |     def __len__(self):
102 |         return len(self.samples)
103 | 
104 |     def __getitem__(self, index):
105 |         path, target = self.samples[index]
106 |         img = self.dataset.loader(path)
107 |         if self.dataset.transform is not None:
108 |             img = self.dataset.transform(img)
109 |         if self.dataset.target_transform is not None:
110 |             target = self.dataset.target_transform(target)
111 |         return img, target
112 | 
113 | 
114 | def make_imagenet1k(
115 |     transform,
116 |     batch_size,
117 |     collator=None,
118 |     pin_mem=True,
119 |     num_workers=8,
120 |     world_size=1,
121 |     rank=0,
122 |     root_path=None,
123 |     image_folder=None,
124 |     training=True,
125 |     drop_last=True,
126 |     persistent_workers=False,
127 |     subset_file=None,
128 | ):
129 |     dataset = ImageNet(
130 |         root=root_path,
131 |         image_folder=image_folder,
132 |         transform=transform,
133 |         train=training,
134 |         index_targets=False,
135 |     )
136 |     if subset_file is not None:
137 |         dataset = ImageNetSubset(dataset, subset_file)
138 |     logger.info("ImageNet dataset created")
139 |     dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset=dataset, num_replicas=world_size, rank=rank)
140 |     data_loader = torch.utils.data.DataLoader(
141 |         dataset,
142 |         collate_fn=collator,
143 |         sampler=dist_sampler,
144 |         batch_size=batch_size,
145 |         drop_last=drop_last,
146 |         pin_memory=pin_mem,
147 |         num_workers=num_workers,
148 |         persistent_workers=persistent_workers,
149 |     )
150 |     logger.info("ImageNet unsupervised data loader created")
151 | 
152 |     return dataset, data_loader, dist_sampler
153 | 


--------------------------------------------------------------------------------
/src/datasets/utils/utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from src.utils.cluster import dataset_paths
 7 | from src.utils.logging import get_logger
 8 | 
 9 | logger = get_logger("Datasets utils")
10 | 
11 | 
12 | def get_dataset_paths(datasets: list[str]):
13 |     paths = []
14 |     for d in datasets:
15 |         try:
16 |             path = dataset_paths().get(d)
17 |         except Exception:
18 |             raise Exception(f"Unknown dataset: {d}")
19 |         paths.append(path)
20 |     logger.info(f"Datapaths {paths}")
21 |     return paths
22 | 


--------------------------------------------------------------------------------
/src/datasets/utils/video/functional.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import numbers
  7 | 
  8 | import cv2
  9 | import numpy as np
 10 | import PIL
 11 | import torch
 12 | from torchvision.transforms import functional as tvf
 13 | 
 14 | 
 15 | def _is_tensor_clip(clip):
 16 |     return torch.is_tensor(clip) and clip.ndimension() == 4
 17 | 
 18 | 
 19 | def crop_clip(clip, min_h, min_w, h, w):
 20 |     if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
 21 |         if clip[0].shape[-1] == 3:
 22 |             cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip]
 23 |         else:
 24 |             assert clip[0].shape[0] == 3
 25 |             cropped = [img[:, min_h : min_h + h, min_w : min_w + w] for img in clip]
 26 | 
 27 |     elif isinstance(clip[0], PIL.Image.Image):
 28 |         cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip]
 29 | 
 30 |     else:
 31 |         raise TypeError(
 32 |             "Expected numpy.ndarray, PIL.Image or torch.Tensor " + "but got list of {0}".format(type(clip[0]))
 33 |         )
 34 |     return cropped
 35 | 
 36 | 
 37 | def resize_clip(clip, size, interpolation="bilinear"):
 38 |     if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor):
 39 |         if isinstance(size, numbers.Number):
 40 |             if clip[0].shape[-1] == 3:
 41 |                 im_h, im_w, im_c = clip[0].shape
 42 |             else:
 43 |                 assert clip[0].shape[0] == 3
 44 |                 im_c, im_h, im_w = clip[0].shape
 45 |             # Min spatial dim already matches minimal size
 46 |             if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
 47 |                 return clip
 48 |             new_h, new_w = get_resize_sizes(im_h, im_w, size)
 49 |             size = (new_w, new_h)
 50 |         else:
 51 |             size = size[0], size[1]
 52 | 
 53 |         if isinstance(clip[0], np.ndarray):
 54 |             if interpolation == "bilinear":
 55 |                 np_inter = cv2.INTER_LINEAR
 56 |             else:
 57 |                 np_inter = cv2.INTER_NEAREST
 58 |             scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip]
 59 |         else:  # isinstance(clip[0], torch.Tensor)
 60 |             if interpolation == "bilinear":
 61 |                 np_inter = tvf.InterpolationMode.BILINEAR
 62 |             else:
 63 |                 np_inter = tvf.InterpolationMode.NEAREST
 64 |             size = (size[1], size[0])  # torchvision transforms expect the size in (h, w) order.
 65 |             scaled = [tvf.resize(img, size, interpolation=np_inter) for img in clip]
 66 |     elif isinstance(clip[0], PIL.Image.Image):
 67 |         if isinstance(size, numbers.Number):
 68 |             im_w, im_h = clip[0].size
 69 |             # Min spatial dim already matches minimal size
 70 |             if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size):
 71 |                 return clip
 72 |             new_h, new_w = get_resize_sizes(im_h, im_w, size)
 73 |             size = (new_w, new_h)
 74 |         else:
 75 |             size = size[1], size[0]
 76 |         if interpolation == "bilinear":
 77 |             pil_inter = PIL.Image.BILINEAR
 78 |         else:
 79 |             pil_inter = PIL.Image.NEAREST
 80 |         scaled = [img.resize(size, pil_inter) for img in clip]
 81 |     else:
 82 |         raise TypeError(
 83 |             "Expected numpy.ndarray, PIL.Image or torch.Tensor, but got list of {0}".format(type(clip[0]))
 84 |         )
 85 |     return scaled
 86 | 
 87 | 
 88 | def get_resize_sizes(im_h, im_w, size):
 89 |     if im_w < im_h:
 90 |         ow = size
 91 |         oh = int(size * im_h / im_w)
 92 |     else:
 93 |         oh = size
 94 |         ow = int(size * im_w / im_h)
 95 |     return oh, ow
 96 | 
 97 | 
 98 | def normalize(clip, mean, std, inplace=False):
 99 |     if not _is_tensor_clip(clip):
100 |         raise TypeError("clip should be a 4-dimensional torch tensor.")
101 | 
102 |     if not inplace:
103 |         clip = clip.clone()
104 | 
105 |     dtype = clip.dtype
106 |     mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
107 |     std = torch.as_tensor(std, dtype=dtype, device=clip.device)
108 |     clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
109 | 
110 |     return clip
111 | 
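A minimal usage sketch for resize_clip and crop_clip (the import path src.datasets.utils.video.functional is an assumption here; adjust to where this module actually lives). An integer size resizes the shorter side while preserving aspect ratio, after which crop_clip can take a center crop:

    import numpy as np

    from src.datasets.utils.video import functional as F

    clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]
    resized = F.resize_clip(clip, 224, interpolation="bilinear")  # shorter side -> 224, frames become (224, 298, 3)
    h, w, _ = resized[0].shape
    cropped = F.crop_clip(resized, (h - 224) // 2, (w - 224) // 2, 224, 224)  # center 224x224 crop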


--------------------------------------------------------------------------------
/src/datasets/utils/video/transforms_builder.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | from typing import Optional
  7 | 
  8 | import torch
  9 | import torchvision.transforms as transforms
 10 | 
 11 | import src.datasets.utils.video.transforms as video_transforms
 12 | from src.datasets.utils.video.randerase import RandomErasing
 13 | 
 14 | 
 15 | def make_transforms(
 16 |     random_horizontal_flip=True,
 17 |     random_resize_aspect_ratio=(3 / 4, 4 / 3),
 18 |     random_resize_scale=(0.3, 1.0),
 19 |     reprob=0.0,
 20 |     auto_augment=False,
 21 |     motion_shift=False,
 22 |     crop_size=224,
 23 |     normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
 24 |     pad_frame_count: Optional[int] = None,
 25 |     pad_frame_method: str = "circulant",
 26 | ):
 27 |     _frames_augmentation = VideoTransform(
 28 |         random_horizontal_flip=random_horizontal_flip,
 29 |         random_resize_aspect_ratio=random_resize_aspect_ratio,
 30 |         random_resize_scale=random_resize_scale,
 31 |         reprob=reprob,
 32 |         auto_augment=auto_augment,
 33 |         motion_shift=motion_shift,
 34 |         crop_size=crop_size,
 35 |         normalize=normalize,
 36 |         pad_frame_count=pad_frame_count,
 37 |         pad_frame_method=pad_frame_method,
 38 |     )
 39 |     return _frames_augmentation
 40 | 
 41 | 
 42 | class VideoTransform(object):
 43 | 
 44 |     def __init__(
 45 |         self,
 46 |         random_horizontal_flip=True,
 47 |         random_resize_aspect_ratio=(3 / 4, 4 / 3),
 48 |         random_resize_scale=(0.3, 1.0),
 49 |         reprob=0.0,
 50 |         auto_augment=False,
 51 |         motion_shift=False,
 52 |         crop_size=224,
 53 |         normalize=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
 54 |         pad_frame_count: Optional[int] = None,
 55 |         pad_frame_method: str = "circulant",
 56 |     ):
 57 |         self.random_horizontal_flip = random_horizontal_flip
 58 |         self.random_resize_aspect_ratio = random_resize_aspect_ratio
 59 |         self.random_resize_scale = random_resize_scale
 60 |         self.auto_augment = auto_augment
 61 |         self.motion_shift = motion_shift
 62 |         self.crop_size = crop_size
 63 |         self.mean = torch.tensor(normalize[0], dtype=torch.float32)
 64 |         self.std = torch.tensor(normalize[1], dtype=torch.float32)
 65 |         self.pad_frame_count = pad_frame_count
 66 |         self.pad_frame_method = pad_frame_method
 67 | 
 68 |         if not self.auto_augment:
 69 |             # Without auto-augment the buffer stays in the uint8 [0, 255] range, so scale the normalization stats to match.
 70 |             self.mean *= 255.0
 71 |             self.std *= 255.0
 72 | 
 73 |         self.autoaug_transform = video_transforms.create_random_augment(
 74 |             input_size=(crop_size, crop_size),
 75 |             auto_augment="rand-m7-n4-mstd0.5-inc1",
 76 |             interpolation="bicubic",
 77 |         )
 78 | 
 79 |         self.spatial_transform = (
 80 |             video_transforms.random_resized_crop_with_shift if motion_shift else video_transforms.random_resized_crop
 81 |         )
 82 | 
 83 |         self.reprob = reprob
 84 |         self.erase_transform = RandomErasing(
 85 |             reprob,
 86 |             mode="pixel",
 87 |             max_count=1,
 88 |             num_splits=1,
 89 |             device="cpu",
 90 |         )
 91 | 
 92 |     def __call__(self, buffer):
 93 | 
 94 |         if self.auto_augment:
 95 |             buffer = [transforms.ToPILImage()(frame) for frame in buffer]
 96 |             buffer = self.autoaug_transform(buffer)
 97 |             buffer = [transforms.ToTensor()(img) for img in buffer]
 98 |             buffer = torch.stack(buffer)  # T C H W
 99 |             buffer = buffer.permute(0, 2, 3, 1)  # T H W C
100 |         elif torch.is_tensor(buffer):
101 |             # TODO: ensure input is always a tensor?
102 |             buffer = buffer.to(torch.float32)
103 |         else:
104 |             buffer = torch.tensor(buffer, dtype=torch.float32)
105 | 
106 |         buffer = buffer.permute(3, 0, 1, 2)  # T H W C -> C T H W
107 | 
108 |         buffer = self.spatial_transform(
109 |             images=buffer,
110 |             target_height=self.crop_size,
111 |             target_width=self.crop_size,
112 |             scale=self.random_resize_scale,
113 |             ratio=self.random_resize_aspect_ratio,
114 |         )
115 |         if self.random_horizontal_flip:
116 |             buffer, _ = video_transforms.horizontal_flip(0.5, buffer)
117 | 
118 |         buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
119 |         if self.reprob > 0:
120 |             buffer = buffer.permute(1, 0, 2, 3)
121 |             buffer = self.erase_transform(buffer)
122 |             buffer = buffer.permute(1, 0, 2, 3)
123 | 
124 |         if self.pad_frame_count is not None:
125 |             buffer = video_transforms.frame_pad(buffer, self.pad_frame_count, self.pad_frame_method)
126 | 
127 |         return buffer
128 | 
129 | 
130 | def tensor_normalize(tensor, mean, std):
131 |     """
132 |     Normalize a given tensor by subtracting the mean and dividing by the std.
133 |     Args:
134 |         tensor (tensor): tensor to normalize.
135 |         mean (tensor or list): mean value to subtract.
136 |         std (tensor or list): std to divide.
137 |     """
138 |     if tensor.dtype == torch.uint8:
139 |         tensor = tensor.float()
140 |         tensor = tensor / 255.0
141 |     if isinstance(mean, list):
142 |         mean = torch.tensor(mean)
143 |     if isinstance(std, list):
144 |         std = torch.tensor(std)
145 |     tensor = tensor - mean
146 |     tensor = tensor / std
147 |     return tensor
148 | 
149 | 
150 | def _tensor_normalize_inplace(tensor, mean, std):
151 |     """
152 |     Normalize a given tensor by subtracting the mean and dividing by the std.
153 |     Args:
154 |         tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
155 |         mean (tensor): mean value to subtract (in 0 to 255 floats).
156 |         std (tensor): std to divide (in 0 to 255 floats).
157 |     """
158 |     if tensor.dtype == torch.uint8:
159 |         tensor = tensor.float()
160 | 
161 |     C, T, H, W = tensor.shape
162 |     tensor = tensor.view(C, -1).permute(1, 0)  # Make C the last dimension
163 |     tensor.sub_(mean).div_(std)
164 |     tensor = tensor.permute(1, 0).view(C, T, H, W)  # Put C back in front
165 |     return tensor
166 | 
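A minimal usage sketch for make_transforms: VideoTransform.__call__ expects a (T, H, W, C) buffer and returns a normalized (C, T, H, W) float tensor. The dummy uint8 buffer below is illustrative only:

    import torch

    from src.datasets.utils.video.transforms_builder import make_transforms

    transform = make_transforms(crop_size=224, auto_augment=False, reprob=0.0)
    buffer = torch.randint(0, 256, (16, 256, 320, 3), dtype=torch.uint8)  # T H W C, uint8 frames
    out = transform(buffer)
    print(out.shape)  # torch.Size([3, 16, 224, 224])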


--------------------------------------------------------------------------------
/src/datasets/utils/video/volume_transforms.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import numpy as np
  7 | import torch
  8 | from PIL import Image
  9 | 
 10 | 
 11 | def convert_img(img):
 12 |     """Converts (H, W, C) numpy.ndarray to (C, H, W) format"""
 13 |     if len(img.shape) == 3:
 14 |         img = img.transpose(2, 0, 1)
 15 |     if len(img.shape) == 2:
 16 |         img = np.expand_dims(img, 0)
 17 |     return img
 18 | 
 19 | 
 20 | class ClipToTensor(object):
 21 |     """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
 22 |     to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
 23 |     """
 24 | 
 25 |     def __init__(self, channel_nb=3, div_255=True, numpy=False):
 26 |         self.channel_nb = channel_nb
 27 |         self.div_255 = div_255
 28 |         self.numpy = numpy
 29 | 
 30 |     def __call__(self, clip):
 31 |         """
 32 |         Args: clip (list of numpy.ndarray): clip (list of images)
 33 |         to be converted to tensor.
 34 |         """
 35 |         # Retrieve shape
 36 |         if isinstance(clip[0], np.ndarray):
 37 |             h, w, ch = clip[0].shape
 38 |             assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(ch, self.channel_nb)
 39 |         elif isinstance(clip[0], Image.Image):
 40 |             w, h = clip[0].size
 41 |         elif isinstance(clip[0], torch.Tensor):
 42 |             tensor_clip = torch.stack(clip)
 43 |             # Converting (T, C, H, W) -> (C, T, H, W) to match what `convert_img` followed by
 44 |             # `np_clip[:, img_idx, :, :] = img` does for other data types.
 45 |             tensor_clip = tensor_clip.permute(1, 0, 2, 3)
 46 |             if not isinstance(tensor_clip, torch.FloatTensor):
 47 |                 tensor_clip = tensor_clip.float()
 48 |             if self.div_255:
 49 |                 tensor_clip = torch.div(tensor_clip, 255)
 50 |             return tensor_clip
 51 |         else:
 52 |             raise TypeError(
 53 |                 "Expected numpy.ndarray or PIL.Image or torch.Tensor "
 54 |                 "but got list of {0}".format(
 55 |                     type(clip[0])
 56 |                 )
 57 |             )
 58 | 
 59 |         np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
 60 | 
 61 |         # Convert
 62 |         for img_idx, img in enumerate(clip):
 63 |             if isinstance(img, np.ndarray):
 64 |                 pass
 65 |             elif isinstance(img, Image.Image):
 66 |                 img = np.array(img, copy=False)
 67 |             else:
 68 |                 raise TypeError(
 69 |                     "Expected numpy.ndarray or PIL.Image "
 70 |                     "but got list of {0}".format(
 71 |                         type(img)
 72 |                     )
 73 |                 )
 74 |             img = convert_img(img)
 75 |             np_clip[:, img_idx, :, :] = img
 76 | 
 77 |         if self.numpy:
 78 |             if self.div_255:
 79 |                 np_clip = np_clip / 255.0
 80 |             return np_clip
 81 | 
 82 |         else:
 83 |             tensor_clip = torch.from_numpy(np_clip)
 84 | 
 85 |             if not isinstance(tensor_clip, torch.FloatTensor):
 86 |                 tensor_clip = tensor_clip.float()
 87 |             if self.div_255:
 88 |                 tensor_clip = torch.div(tensor_clip, 255)
 89 |             return tensor_clip
 90 | 
 91 | 
 92 | # Note this norms data to -1/1
 93 | class ClipToTensor_K(object):
 94 |     """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
 95 |     to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0]
 96 |     """
 97 | 
 98 |     def __init__(self, channel_nb=3, div_255=True, numpy=False):
 99 |         self.channel_nb = channel_nb
100 |         self.div_255 = div_255
101 |         self.numpy = numpy
102 | 
103 |     def __call__(self, clip):
104 |         """
105 |         Args: clip (list of numpy.ndarray): clip (list of images)
106 |         to be converted to tensor.
107 |         """
108 |         # Retrieve shape
109 |         if isinstance(clip[0], np.ndarray):
110 |             h, w, ch = clip[0].shape
111 |             assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(ch, self.channel_nb)
112 |         elif isinstance(clip[0], Image.Image):
113 |             w, h = clip[0].size
114 |         else:
115 |             raise TypeError(
116 |                 "Expected numpy.ndarray or PIL.Image "
117 |                 "but got list of {0}".format(
118 |                     type(clip[0])
119 |                 )
120 |             )
121 | 
122 |         np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])
123 | 
124 |         # Convert
125 |         for img_idx, img in enumerate(clip):
126 |             if isinstance(img, np.ndarray):
127 |                 pass
128 |             elif isinstance(img, Image.Image):
129 |                 img = np.array(img, copy=False)
130 |             else:
131 |                 raise TypeError(
132 |                     "Expected numpy.ndarray or PIL.Image "
133 |                     "but got list of {0}".format(
134 |                         type(img)
135 |                     )
136 |                 )
137 |             img = convert_img(img)
138 |             np_clip[:, img_idx, :, :] = img
139 |         if self.numpy:
140 |             if self.div_255:
141 |                 np_clip = (np_clip - 127.5) / 127.5
142 |             return np_clip
143 | 
144 |         else:
145 |             tensor_clip = torch.from_numpy(np_clip)
146 | 
147 |             if not isinstance(tensor_clip, torch.FloatTensor):
148 |                 tensor_clip = tensor_clip.float()
149 |             if self.div_255:
150 |                 tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
151 |             return tensor_clip
152 | 
153 | 
154 | class ToTensor(object):
155 |     """Converts numpy array to tensor"""
156 | 
157 |     def __call__(self, array):
158 |         tensor = torch.from_numpy(array)
159 |         return tensor
160 | 
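A minimal usage sketch for ClipToTensor on a list of uint8 frames; the random frames are illustrative only:

    import numpy as np

    from src.datasets.utils.video.volume_transforms import ClipToTensor

    frames = [np.random.randint(0, 256, (128, 160, 3), dtype=np.uint8) for _ in range(4)]
    clip = ClipToTensor()(frames)
    print(clip.shape, clip.dtype)  # torch.Size([3, 4, 128, 160]) torch.float32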


--------------------------------------------------------------------------------
/src/datasets/utils/worker_init_fn.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | 
 3 | # Copyright The Lightning AI team.
 4 | 
 5 | # Licensed under the Apache License, Version 2.0 (the "License");
 6 | # you may not use this file except in compliance with the License.
 7 | # You may obtain a copy of the License at
 8 | 
 9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | 
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | # This code originally comes from PyTorch Lightning with some light modifications:
18 | # https://github.com/Lightning-AI/pytorch-lightning/blob/a944e7744e57a5a2c13f3c73b9735edf2f71e329/src/lightning/fabric/utilities/seed.py
19 | 
20 | 
21 | import os
22 | import random
23 | from typing import Optional
24 | 
25 | import numpy as np
26 | import torch
27 | 
28 | from src.utils.logging import get_logger
29 | 
30 | logger = get_logger("worker_init_fn")
31 | 
32 | 
33 | def _generate_seed_sequence(base_seed: int, worker_id: int, global_rank: int, count: int) -> list[int]:
34 |     """Generates a sequence of seeds from a base seed, worker id and rank using the linear congruential generator (LCG)
35 |     algorithm."""
36 |     # Combine base seed, worker id and rank into a unique 64-bit number
37 |     combined_seed = (base_seed << 32) | (worker_id << 16) | global_rank
38 |     seeds = []
39 |     for _ in range(count):
40 |         # x_(n+1) = (a * x_n + c) mod m, with c = 1, m = 2^64, and a = D. Knuth's LCG multiplier
41 |         combined_seed = (combined_seed * 6364136223846793005 + 1) & ((1 << 64) - 1)
42 |         seeds.append(combined_seed)
43 |     return seeds
44 | 
45 | 
46 | def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
47 |     r"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with
48 |     ``seed_everything(seed, workers=True)``.
49 | 
50 |     See also the PyTorch documentation on
51 |     `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
52 | 
53 |     """
54 |     # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
55 |     if rank is None:
56 |         procid = os.environ.get("SLURM_PROCID")
57 |         if procid is None:
58 |             logger.warning("SLURM_PROCID is not set, setting rank to 0")
59 |             rank = 0
60 |         else:
61 |             rank = int(procid)
62 | 
63 |     process_seed = torch.initial_seed()
64 |     # back out the base seed so we can use all the bits
65 |     base_seed = process_seed - worker_id
66 |     logger.debug(
67 |         f"Initializing random number generators of process {rank} worker {worker_id} with base seed {base_seed}"
68 |     )
69 |     seed_sequence = _generate_seed_sequence(base_seed, worker_id, rank, count=4)
70 |     torch.manual_seed(seed_sequence[0])  # torch takes a 64-bit seed
71 |     random.seed((seed_sequence[1] << 32) | seed_sequence[2])  # combine two 64-bit seeds
72 | 
73 |     ss = np.random.SeedSequence([base_seed, worker_id, rank])
74 |     np_rng_seed = ss.generate_state(4)
75 | 
76 |     np.random.seed(np_rng_seed)
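A minimal usage sketch for pl_worker_init_function, binding the rank explicitly so the SLURM lookup is skipped; the TensorDataset is just a stand-in dataset for illustration:

    from functools import partial

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from src.datasets.utils.worker_init_fn import pl_worker_init_function

    dataset = TensorDataset(torch.arange(100))
    loader = DataLoader(
        dataset,
        batch_size=8,
        num_workers=2,
        worker_init_fn=partial(pl_worker_init_function, rank=0),  # rank would normally come from the distributed setup
    )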
77 | 


--------------------------------------------------------------------------------
/src/hub/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/vjepa2/2dbda75ac7ab8e4e16f1c987c2b76261b54a6e3d/src/hub/__init__.py


--------------------------------------------------------------------------------
/src/hub/backbones.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import torch
  7 | 
  8 | VJEPA_BASE_URL = "https://dl.fbaipublicfiles.com/vjepa2"
  9 | 
 10 | # for testing
 11 | # VJEPA_BASE_URL = "http://localhost:8300"
 12 | 
 13 | ARCH_NAME_MAP = {
 14 |     "vit_large": ("vit_large", "vitl"),
 15 |     "vit_huge": ("vit_huge", "vith"),
 16 |     "vit_giant": ("vit_giant_xformers", "vitg"),
 17 |     "vit_ac_giant": ("vit_giant_xformers", "vjepa2-ac-vitg"),
 18 |     "vit_giant_384": ("vit_giant_xformers", "vitg-384"),
 19 | }
 20 | 
 21 | 
 22 | def _clean_backbone_key(state_dict):
 23 |     for key, val in state_dict.copy().items():
 24 |         _ = state_dict.pop(key)
 25 |         key = key.replace("module.", "")
 26 |         key = key.replace("backbone.", "")
 27 |         state_dict[key] = val
 28 |     return state_dict
 29 | 
 30 | 
 31 | def _make_vjepa2_ac_model(
 32 |     *,
 33 |     model_name: str = "vit_ac_giant",
 34 |     img_size=256,
 35 |     patch_size=16,
 36 |     tubelet_size=2,
 37 |     num_frames=64,
 38 |     pretrained: bool = True,
 39 |     **kwargs,
 40 | ):
 41 |     from ..models import ac_predictor as vit_ac_predictor
 42 |     from ..models import vision_transformer as vit_encoder
 43 | 
 44 |     vit_encoder_kwargs = dict(
 45 |         patch_size=patch_size,
 46 |         img_size=(img_size, img_size),
 47 |         num_frames=num_frames,
 48 |         tubelet_size=tubelet_size,
 49 |         use_sdpa=True,
 50 |         use_SiLU=False,
 51 |         wide_SiLU=True,
 52 |         uniform_power=False,
 53 |         use_rope=True,
 54 |     )
 55 |     vit_encoder_kwargs.update(**kwargs)
 56 | 
 57 |     arch_name = ARCH_NAME_MAP[model_name][0]
 58 |     encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs)
 59 | 
 60 |     vit_predictor_kwargs = dict(
 61 |         img_size=(img_size, img_size),
 62 |         patch_size=patch_size,
 63 |         num_frames=num_frames,
 64 |         tubelet_size=tubelet_size,
 65 |         embed_dim=encoder.embed_dim,
 66 |     )
 67 |     vit_predictor_kwargs.update(**kwargs)
 68 | 
 69 |     predictor = vit_ac_predictor.__dict__["vit_ac_predictor"](**vit_predictor_kwargs)
 70 | 
 71 |     if pretrained:
 72 |         model_file = ARCH_NAME_MAP[model_name][-1]
 73 |         url = VJEPA_BASE_URL + f"/{model_file}.pt"
 74 |         state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
 75 |         encoder_state_dict = _clean_backbone_key(state_dict["encoder"])
 76 |         encoder.load_state_dict(encoder_state_dict, strict=False)
 77 |         predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
 78 |         predictor.load_state_dict(predictor_state_dict, strict=True)
 79 | 
 80 |     return encoder, predictor
 81 | 
 82 | 
 83 | def _make_vjepa2_model(
 84 |     *,
 85 |     model_name: str = "vit_large",
 86 |     img_size=256,
 87 |     patch_size=16,
 88 |     tubelet_size=2,
 89 |     num_frames=64,
 90 |     pretrained: bool = True,
 91 |     **kwargs,
 92 | ):
 93 |     from ..models import predictor as vit_predictor
 94 |     from ..models import vision_transformer as vit_encoder
 95 | 
 96 |     vit_encoder_kwargs = dict(
 97 |         patch_size=patch_size,
 98 |         img_size=(img_size, img_size),
 99 |         num_frames=num_frames,
100 |         tubelet_size=tubelet_size,
101 |         use_sdpa=True,
102 |         use_SiLU=False,
103 |         wide_SiLU=True,
104 |         uniform_power=False,
105 |         use_rope=True,
106 |     )
107 |     vit_encoder_kwargs.update(**kwargs)
108 | 
109 |     arch_name = ARCH_NAME_MAP[model_name][0]
110 |     encoder = vit_encoder.__dict__[arch_name](**vit_encoder_kwargs)
111 | 
112 |     vit_predictor_kwargs = dict(
113 |         img_size=(img_size, img_size),
114 |         patch_size=patch_size,
115 |         use_mask_tokens=True,
116 |         embed_dim=encoder.embed_dim,
117 |         predictor_embed_dim=384,
118 |         num_frames=num_frames,
119 |         tubelet_size=tubelet_size,
120 |         depth=12,
121 |         num_heads=12,
122 |         num_mask_tokens=10,
123 |         use_rope=True,
124 |         uniform_power=False,
125 |         use_sdpa=True,
126 |         use_silu=False,
127 |         wide_silu=True,
128 |     )
129 |     vit_predictor_kwargs.update(**kwargs)
130 | 
131 |     predictor = vit_predictor.__dict__["vit_predictor"](**vit_predictor_kwargs)
132 | 
133 |     if pretrained:
134 |         model_file = ARCH_NAME_MAP[model_name][-1]
135 |         url = VJEPA_BASE_URL + f"/{model_file}.pt"
136 |         state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
137 |         encoder_state_dict = _clean_backbone_key(state_dict["encoder"])
138 |         encoder.load_state_dict(encoder_state_dict, strict=False)  # state_dict has pos_embed but we use RoPE
139 |         predictor_state_dict = _clean_backbone_key(state_dict["predictor"])
140 |         predictor.load_state_dict(predictor_state_dict, strict=False)  # state_dict has pos_embed but we use RoPE
141 | 
142 |     return encoder, predictor
143 | 
144 | 
145 | def vjepa2_vit_large(*, pretrained: bool = True, **kwargs):
146 |     """
147 |     VJEPA 2 ViT-Large model
148 |     """
149 |     return _make_vjepa2_model(model_name="vit_large", img_size=256, pretrained=pretrained, **kwargs)
150 | 
151 | 
152 | def vjepa2_vit_huge(*, pretrained: bool = True, **kwargs):
153 |     """
154 |     VJEPA 2 ViT-Huge model
155 |     """
156 |     return _make_vjepa2_model(model_name="vit_huge", img_size=256, pretrained=pretrained, **kwargs)
157 | 
158 | 
159 | def vjepa2_vit_giant(*, pretrained: bool = True, **kwargs):
160 |     """
161 |     VJEPA 2 ViT-giant model
162 |     """
163 |     return _make_vjepa2_model(model_name="vit_giant", img_size=256, pretrained=pretrained, **kwargs)
164 | 
165 | 
166 | def vjepa2_vit_giant_384(*, pretrained: bool = True, **kwargs):
167 |     """
168 |     VJEPA 2 ViT-giant-384 model
169 |     """
170 |     return _make_vjepa2_model(model_name="vit_giant_384", img_size=384, pretrained=pretrained, **kwargs)
171 | 
172 | 
173 | def vjepa2_ac_vit_giant(*, pretrained: bool = True, **kwargs):
174 |     """
175 |     VJEPA 2-AC ViT-giant model
176 |     """
177 |     return _make_vjepa2_ac_model(model_name="vit_ac_giant", img_size=256, pretrained=pretrained, **kwargs)
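A minimal usage sketch: the factory functions above return an (encoder, predictor) pair, and pretrained=True fetches the checkpoint from VJEPA_BASE_URL:

    from src.hub.backbones import vjepa2_vit_large

    encoder, predictor = vjepa2_vit_large(pretrained=False)  # pretrained=True downloads the weights
    print(encoder.embed_dim)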
178 | 


--------------------------------------------------------------------------------
/src/masks/default.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | from logging import getLogger
 7 | 
 8 | import torch
 9 | 
10 | _GLOBAL_SEED = 0
11 | logger = getLogger()
12 | 
13 | 
14 | class DefaultCollator(object):
15 | 
16 |     def __call__(self, batch):
17 |         collated_batch = torch.utils.data.default_collate(batch)
18 |         return collated_batch, None, None
19 | 


--------------------------------------------------------------------------------
/src/masks/utils.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch
 7 | 
 8 | 
 9 | def apply_masks(x, masks, concat=True):
10 |     """
11 |     :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
12 |     :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep
13 |     """
14 |     all_x = []
15 |     for m in masks:
16 |         mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
17 |         all_x += [torch.gather(x, dim=1, index=mask_keep)]
18 |     if not concat:
19 |         return all_x
20 | 
21 |     return torch.cat(all_x, dim=0)
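A small worked sketch of apply_masks: each mask lists the patch indices to keep, and gather pulls those rows out per sample; the tiny shapes are illustrative only:

    import torch

    from src.masks.utils import apply_masks

    x = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # [B=2, N=4, D=3]
    masks = [torch.tensor([[0, 2], [1, 3]])]  # keep patches 0,2 for sample 0 and 1,3 for sample 1
    out = apply_masks(x, masks)
    print(out.shape)  # torch.Size([2, 2, 3])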
22 | 


--------------------------------------------------------------------------------
/src/models/attentive_pooler.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | 
  7 | import math
  8 | 
  9 | import torch
 10 | import torch.nn as nn
 11 | 
 12 | from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock
 13 | from src.utils.tensors import trunc_normal_
 14 | 
 15 | 
 16 | class AttentivePooler(nn.Module):
 17 |     """Attentive Pooler"""
 18 | 
 19 |     def __init__(
 20 |         self,
 21 |         num_queries=1,
 22 |         embed_dim=768,
 23 |         num_heads=12,
 24 |         mlp_ratio=4.0,
 25 |         depth=1,
 26 |         norm_layer=nn.LayerNorm,
 27 |         init_std=0.02,
 28 |         qkv_bias=True,
 29 |         complete_block=True,
 30 |         use_activation_checkpointing=False,
 31 |     ):
 32 |         super().__init__()
 33 |         self.use_activation_checkpointing = use_activation_checkpointing
 34 |         self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
 35 | 
 36 |         self.complete_block = complete_block
 37 |         if complete_block:
 38 |             self.cross_attention_block = CrossAttentionBlock(
 39 |                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer
 40 |             )
 41 |         else:
 42 |             self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias)
 43 | 
 44 |         self.blocks = None
 45 |         if depth > 1:
 46 |             self.blocks = nn.ModuleList(
 47 |                 [
 48 |                     Block(
 49 |                         dim=embed_dim,
 50 |                         num_heads=num_heads,
 51 |                         mlp_ratio=mlp_ratio,
 52 |                         qkv_bias=qkv_bias,
 53 |                         qk_scale=False,
 54 |                         norm_layer=norm_layer,
 55 |                     )
 56 |                     for i in range(depth - 1)
 57 |                 ]
 58 |             )
 59 | 
 60 |         self.init_std = init_std
 61 |         trunc_normal_(self.query_tokens, std=self.init_std)
 62 |         self.apply(self._init_weights)
 63 |         self._rescale_blocks()
 64 | 
 65 |     def _rescale_blocks(self):
 66 |         def rescale(param, layer_id):
 67 |             param.div_(math.sqrt(2.0 * layer_id))
 68 | 
 69 |         layer_id = 0
 70 |         if self.blocks is not None:
 71 |             for layer_id, layer in enumerate(self.blocks):
 72 |                 rescale(layer.attn.proj.weight.data, layer_id + 1)
 73 |                 rescale(layer.mlp.fc2.weight.data, layer_id + 1)
 74 | 
 75 |         if self.complete_block:
 76 |             rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1)
 77 | 
 78 |     def _init_weights(self, m):
 79 |         if isinstance(m, nn.Linear):
 80 |             trunc_normal_(m.weight, std=self.init_std)
 81 |             if isinstance(m, nn.Linear) and m.bias is not None:
 82 |                 nn.init.constant_(m.bias, 0)
 83 |         elif isinstance(m, nn.LayerNorm):
 84 |             nn.init.constant_(m.bias, 0)
 85 |             nn.init.constant_(m.weight, 1.0)
 86 |         elif isinstance(m, nn.Conv2d):
 87 |             trunc_normal_(m.weight, std=self.init_std)
 88 |             if m.bias is not None:
 89 |                 nn.init.constant_(m.bias, 0)
 90 | 
 91 |     def forward(self, x):
 92 |         if self.blocks is not None:
 93 |             for blk in self.blocks:
 94 |                 if self.use_activation_checkpointing:
 95 |                     x = torch.utils.checkpoint.checkpoint(blk, x, False, None, use_reentrant=False)
 96 |                 else:
 97 |                     x = blk(x)
 98 |         q = self.query_tokens.repeat(len(x), 1, 1)
 99 |         q = self.cross_attention_block(q, x)
100 |         return q
101 | 
102 | 
103 | class AttentiveClassifier(nn.Module):
104 |     """Attentive Classifier"""
105 | 
106 |     def __init__(
107 |         self,
108 |         embed_dim=768,
109 |         num_heads=12,
110 |         mlp_ratio=4.0,
111 |         depth=1,
112 |         norm_layer=nn.LayerNorm,
113 |         init_std=0.02,
114 |         qkv_bias=True,
115 |         num_classes=1000,
116 |         complete_block=True,
117 |         use_activation_checkpointing=False,
118 |     ):
119 |         super().__init__()
120 |         self.pooler = AttentivePooler(
121 |             num_queries=1,
122 |             embed_dim=embed_dim,
123 |             num_heads=num_heads,
124 |             mlp_ratio=mlp_ratio,
125 |             depth=depth,
126 |             norm_layer=norm_layer,
127 |             init_std=init_std,
128 |             qkv_bias=qkv_bias,
129 |             complete_block=complete_block,
130 |             use_activation_checkpointing=use_activation_checkpointing,
131 |         )
132 |         self.linear = nn.Linear(embed_dim, num_classes, bias=True)
133 | 
134 |     def forward(self, x):
135 |         x = self.pooler(x).squeeze(1)
136 |         x = self.linear(x)
137 |         return x
138 | 
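A minimal usage sketch for AttentiveClassifier on dummy patch features; the batch and token counts are illustrative only:

    import torch

    from src.models.attentive_pooler import AttentiveClassifier

    classifier = AttentiveClassifier(embed_dim=768, num_heads=12, depth=1, num_classes=10)
    features = torch.randn(4, 196, 768)  # [batch, tokens, embed_dim]
    logits = classifier(features)
    print(logits.shape)  # torch.Size([4, 10])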


--------------------------------------------------------------------------------
/src/models/utils/patch_embed.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch.nn as nn
 7 | from einops import rearrange
 8 | 
 9 | 
10 | class PatchEmbed(nn.Module):
11 |     """
12 |     Image to Patch Embedding
13 |     """
14 | 
15 |     def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
16 |         super().__init__()
17 |         self.patch_size = patch_size
18 |         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
19 | 
20 |     def forward(self, x):
21 |         B, C, H, W = x.shape
22 |         x = self.proj(x).flatten(2).transpose(1, 2)
23 |         return x
24 | 
25 | 
26 | class PatchEmbed3D(nn.Module):
27 |     """
28 |     Video to Patch Embedding
29 |     """
30 | 
31 |     def __init__(
32 |         self,
33 |         patch_size=16,
34 |         tubelet_size=2,
35 |         in_chans=3,
36 |         embed_dim=768,
37 |     ):
38 |         super().__init__()
39 |         self.patch_size = patch_size
40 |         self.tubelet_size = tubelet_size
41 | 
42 |         self.proj = nn.Conv3d(
43 |             in_channels=in_chans,
44 |             out_channels=embed_dim,
45 |             kernel_size=(tubelet_size, patch_size, patch_size),
46 |             stride=(tubelet_size, patch_size, patch_size),
47 |         )
48 | 
49 |     def forward(self, x, **kwargs):
50 |         B, C, T, H, W = x.shape
51 |         x = self.proj(x).flatten(2).transpose(1, 2)
52 |         return x
53 | 


--------------------------------------------------------------------------------
/src/models/utils/pos_embs.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import numpy as np
 7 | 
 8 | 
 9 | def get_3d_sincos_pos_embed(embed_dim, grid_size, grid_depth, cls_token=False, uniform_power=False):
10 |     """
11 |     grid_size: int of the grid height and width
12 |     grid_depth: int of the grid depth
13 |     returns:
14 |         pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token)
15 |                 or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token)
16 |     """
17 |     grid_d = np.arange(grid_depth, dtype=float)
18 |     grid_h = np.arange(grid_size, dtype=float)
19 |     grid_w = np.arange(grid_size, dtype=float)
20 |     grid_h, grid_d, grid_w = np.meshgrid(
21 |         grid_h, grid_d, grid_w
22 |     )  # order of meshgrid is very important for indexing as [d,h,w]
23 | 
24 |     if not uniform_power:
25 |         h_embed_dim = embed_dim // 4
26 |         w_embed_dim = embed_dim // 4
27 |         d_embed_dim = embed_dim // 2
28 |     else:
29 |         h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim / 6) * 2)
30 | 
31 |     emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h)  # (T*H*W, D1)
32 |     emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w)  # (T*H*W, D2)
33 |     emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d)  # (T*H*W, D3)
34 |     pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1)
35 |     pos_embed = pos_embed[:, :embed_dim]
36 |     if cls_token:
37 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
38 |     return pos_embed
39 | 
40 | 
41 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
42 |     """
43 |     grid_size: int of the grid height and width
44 |     returns:
45 |         pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token)
46 |                 or [1+grid_size*grid_size, embed_dim] (w/ cls_token)
47 |     """
48 |     grid_h = np.arange(grid_size, dtype=float)
49 |     grid_w = np.arange(grid_size, dtype=float)
50 |     grid_w, grid_h = np.meshgrid(grid_w, grid_h)  # order of meshgrid is very important for indexing as [h, w]
51 | 
52 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h)  # (H*W, D/2)
53 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w)  # (H*W, D/2)
54 |     pos_embed = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
55 |     if cls_token:
56 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
57 |     return pos_embed
58 | 
59 | 
60 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
61 |     """
62 |     embed_dim: output dimension for each position
63 |     grid_size: int of the grid length
64 |     returns:
65 |         pos_embed: [grid_size, embed_dim] (w/o cls_token)
66 |                 or [1+grid_size, embed_dim] (w/ cls_token)
67 |     """
68 |     grid = np.arange(grid_size, dtype=float)
69 |     pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
70 |     if cls_token:
71 |         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
72 |     return pos_embed
73 | 
74 | 
75 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
76 |     """
77 |     embed_dim: output dimension for each position
78 |     pos: a list of positions to be encoded: size (M,)
79 |     returns: (M, D)
80 |     """
81 |     assert embed_dim % 2 == 0
82 |     omega = np.arange(embed_dim // 2, dtype=float)
83 |     omega /= embed_dim / 2.0
84 |     omega = 1.0 / 10000**omega  # (D/2,)
85 | 
86 |     pos = pos.reshape(-1)  # (M,)
87 |     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
88 | 
89 |     emb_sin = np.sin(out)  # (M, D/2)
90 |     emb_cos = np.cos(out)  # (M, D/2)
91 | 
92 |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
93 |     return emb
94 | 
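A minimal usage sketch for get_3d_sincos_pos_embed; the 8x14x14 token grid is illustrative only:

    from src.models.utils.pos_embs import get_3d_sincos_pos_embed

    pos_embed = get_3d_sincos_pos_embed(embed_dim=768, grid_size=14, grid_depth=8)
    print(pos_embed.shape)  # (1568, 768) = (grid_depth * grid_size * grid_size, embed_dim)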


--------------------------------------------------------------------------------
/src/utils/checkpoint_loader.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import os
 7 | import random
 8 | import time
 9 | from typing import Any
10 | 
11 | import torch
12 | from torch.serialization import MAP_LOCATION
13 | 
14 | from src.utils.logging import get_logger
15 | 
16 | logger = get_logger(os.path.basename(__file__))
17 | 
18 | 
19 | def robust_checkpoint_loader(r_path: str, map_location: MAP_LOCATION = "cpu", max_retries: int = 3) -> Any:
20 |     """
21 |     Loads a checkpoint from a path, retrying up to max_retries times if loading fails.
22 |     """
23 |     retries = 0
24 | 
25 |     while retries < max_retries:
26 |         try:
27 |             return torch.load(r_path, map_location=map_location)
28 |         except Exception as e:
29 |             logger.warning(f"Encountered exception when loading checkpoint {e}")
30 |             retries += 1
31 |             if retries < max_retries:
32 |                 sleep_time_s = (2**retries) * random.uniform(1.0, 1.1)
33 |                 logger.warning(f"Sleeping {sleep_time_s}s and trying again, count {retries}/{max_retries}")
34 |                 time.sleep(sleep_time_s)
35 |                 continue
36 |             else:
37 |                 raise e
38 | 


--------------------------------------------------------------------------------
/src/utils/distributed.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import os
  7 | from pathlib import Path
  8 | 
  9 | import torch
 10 | import torch.distributed as dist
 11 | 
 12 | from src.utils.logging import get_logger
 13 | 
 14 | logger = get_logger()
 15 | 
 16 | 
 17 | def init_distributed(port=37129, rank_and_world_size=(None, None)):
 18 |     # try to set all environment variables to avoid triggering a segfault
 19 |     # environment variables can be reallocated during the execution of torch.distributed.init_process_group
 20 |     # the idea is that a race condition may trigger if init_process_group is modifying an environment variable at
 21 |     # the same time as Python, so we try to set all environs before initializing distributed
 22 |     if "SLURM_JOB_ID" in os.environ:
 23 |         # Use the slurm_tmpdir (if it exists) instead of /tmp
 24 |         tmpdir = Path(f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}")
 25 |         if tmpdir.exists():
 26 |             os.environ["TMPDIR"] = str(tmpdir)
 27 | 
 28 |     if dist.is_available() and dist.is_initialized():
 29 |         return dist.get_world_size(), dist.get_rank()
 30 | 
 31 |     rank, world_size = rank_and_world_size
 32 |     os.environ["MASTER_ADDR"] = "localhost"
 33 | 
 34 |     if (rank is None) or (world_size is None):
 35 |         try:
 36 |             world_size = int(os.environ["SLURM_NTASKS"])
 37 |             rank = int(os.environ["SLURM_PROCID"])
 38 |             os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"]
 39 |         except Exception:
 40 |             logger.info("SLURM vars not set (distributed training not available)")
 41 |             world_size, rank = 1, 0
 42 |             return world_size, rank
 43 | 
 44 |     try:
 45 |         os.environ["MASTER_PORT"] = str(port)
 46 |         torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
 47 |     except Exception as e:
 48 |         world_size, rank = 1, 0
 49 |         logger.info(f"Rank: {rank}. Distributed training not available {e}")
 50 | 
 51 |     return world_size, rank
 52 | 
 53 | 
 54 | class AllGather(torch.autograd.Function):
 55 | 
 56 |     @staticmethod
 57 |     def forward(ctx, x):
 58 |         if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
 59 |             x = x.contiguous()
 60 |             outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
 61 |             dist.all_gather(outputs, x)
 62 |             return torch.cat(outputs, 0)
 63 |         return x
 64 | 
 65 |     @staticmethod
 66 |     def backward(ctx, grads):
 67 |         if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
 68 |             s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank()
 69 |             e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1)
 70 |             grads = grads.contiguous()
 71 |             dist.all_reduce(grads)
 72 |             return grads[s:e]
 73 |         return grads
 74 | 
 75 | 
 76 | class AllReduceSum(torch.autograd.Function):
 77 | 
 78 |     @staticmethod
 79 |     def forward(ctx, x):
 80 |         if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
 81 |             x = x.contiguous()
 82 |             dist.all_reduce(x)
 83 |         return x
 84 | 
 85 |     @staticmethod
 86 |     def backward(ctx, grads):
 87 |         return grads
 88 | 
 89 | 
 90 | class AllReduce(torch.autograd.Function):
 91 | 
 92 |     @staticmethod
 93 |     def forward(ctx, x):
 94 |         if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
 95 |             x = x.contiguous() / dist.get_world_size()
 96 |             dist.all_reduce(x)
 97 |         return x
 98 | 
 99 |     @staticmethod
100 |     def backward(ctx, grads):
101 |         return grads
102 | 


--------------------------------------------------------------------------------
/src/utils/logging.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import logging
  7 | import os
  8 | import subprocess
  9 | import sys
 10 | 
 11 | import torch
 12 | 
 13 | 
 14 | def gpu_timer(closure, log_timings=True):
 15 |     """Helper to measure the GPU time taken to execute closure()"""
 16 |     log_timings = log_timings and torch.cuda.is_available()
 17 | 
 18 |     elapsed_time = -1.0
 19 |     if log_timings:
 20 |         start = torch.cuda.Event(enable_timing=True)
 21 |         end = torch.cuda.Event(enable_timing=True)
 22 |         start.record()
 23 | 
 24 |     result = closure()
 25 | 
 26 |     if log_timings:
 27 |         end.record()
 28 |         torch.cuda.synchronize()
 29 |         elapsed_time = start.elapsed_time(end)
 30 | 
 31 |     return result, elapsed_time
 32 | 
 33 | 
 34 | LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(name)-20s][%(funcName)-25s] %(message)s"
 35 | DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
 36 | 
 37 | 
 38 | def get_logger(name=None, force=False):
 39 |     logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force)
 40 |     return logging.getLogger(name=name)
 41 | 
 42 | 
 43 | class CSVLogger(object):
 44 | 
 45 |     def __init__(self, fname, *argv, **kwargs):
 46 |         self.fname = fname
 47 |         self.types = []
 48 |         mode = kwargs.get("mode", "+a")
 49 |         self.delim = kwargs.get("delim", ",")
 50 |         # -- print headers
 51 |         with open(self.fname, mode) as f:
 52 |             for i, v in enumerate(argv, 1):
 53 |                 self.types.append(v[0])
 54 |                 if i < len(argv):
 55 |                     print(v[1], end=self.delim, file=f)
 56 |                 else:
 57 |                     print(v[1], end="\n", file=f)
 58 | 
 59 |     def log(self, *argv):
 60 |         with open(self.fname, "+a") as f:
 61 |             for i, tv in enumerate(zip(self.types, argv), 1):
 62 |                 end = self.delim if i < len(argv) else "\n"
 63 |                 print(tv[0] % tv[1], end=end, file=f)
 64 | 
 65 | 
 66 | class AverageMeter(object):
 67 |     """computes and stores the average and current value"""
 68 | 
 69 |     def __init__(self):
 70 |         self.reset()
 71 | 
 72 |     def reset(self):
 73 |         self.val = 0
 74 |         self.avg = 0
 75 |         self.max = float("-inf")
 76 |         self.min = float("inf")
 77 |         self.sum = 0
 78 |         self.count = 0
 79 | 
 80 |     def update(self, val, n=1):
 81 |         self.val = val
 82 |         try:
 83 |             self.max = max(val, self.max)
 84 |             self.min = min(val, self.min)
 85 |         except Exception:
 86 |             pass
 87 |         self.sum += val * n
 88 |         self.count += n
 89 |         self.avg = self.sum / self.count
 90 | 
 91 | 
 92 | def jepa_rootpath():
 93 |     this_file = os.path.abspath(__file__)
 94 |     return "/".join(this_file.split("/")[:-3])
 95 | 
 96 | 
 97 | def git_information():
 98 |     jepa_root = jepa_rootpath()
 99 |     try:
100 |         resp = (
101 |             subprocess.check_output(["git", "-C", jepa_root, "rev-parse", "HEAD", "--abbrev-ref", "HEAD"])
102 |             .decode("ascii")
103 |             .strip()
104 |         )
105 |         commit, branch = resp.split("\n")
106 |         return f"branch: {branch}\ncommit: {commit}\n"
107 |     except Exception:
108 |         return "unknown"
109 | 


--------------------------------------------------------------------------------
/src/utils/schedulers.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import math
 7 | 
 8 | 
 9 | class WSDSchedule(object):
10 | 
11 |     def __init__(self, optimizer, warmup_steps, anneal_steps, T_max, start_lr, ref_lr, final_lr=0.0):
12 |         self.optimizer = optimizer
13 |         self.start_lr = start_lr
14 |         self.ref_lr = ref_lr
15 |         self.final_lr = final_lr
16 |         self.anneal_steps = anneal_steps
17 |         self.warmup_steps = warmup_steps
18 |         self.T_max = T_max - warmup_steps - anneal_steps
19 |         self._step = 0.0
20 | 
21 |     def step(self):
22 |         self._step += 1
23 |         if self._step < self.warmup_steps:
24 |             progress = float(self._step) / float(max(1, self.warmup_steps))
25 |             new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
26 |         elif self._step < self.T_max + self.warmup_steps:
27 |             new_lr = self.ref_lr
28 |         else:
29 |             _step = self._step - (self.T_max + self.warmup_steps)
30 |             progress = float(_step) / float(max(1, self.anneal_steps))
31 |             new_lr = self.ref_lr + progress * (self.final_lr - self.ref_lr)
32 | 
33 |         for group in self.optimizer.param_groups:
34 |             group["lr"] = new_lr
35 |             if "lr_scale" in group:
36 |                 group["lr"] *= group["lr_scale"]
37 | 
38 |         return new_lr
39 | 
40 | 
41 | class WarmupCosineSchedule(object):
42 | 
43 |     def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0):
44 |         self.optimizer = optimizer
45 |         self.start_lr = start_lr
46 |         self.ref_lr = ref_lr
47 |         self.final_lr = final_lr
48 |         self.warmup_steps = warmup_steps
49 |         self.T_max = T_max - warmup_steps
50 |         self._step = 0.0
51 | 
52 |     def step(self):
53 |         self._step += 1
54 |         if self._step < self.warmup_steps:
55 |             progress = float(self._step) / float(max(1, self.warmup_steps))
56 |             new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr)
57 |         else:
58 |             # -- progress after warmup
59 |             progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max))
60 |             new_lr = max(
61 |                 self.final_lr,
62 |                 self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)),
63 |             )
64 | 
65 |         for group in self.optimizer.param_groups:
66 |             group["lr"] = new_lr
67 | 
68 |         return new_lr
69 | 
70 | 
71 | class CosineWDSchedule(object):
72 | 
73 |     def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0):
74 |         self.optimizer = optimizer
75 |         self.ref_wd = ref_wd
76 |         self.final_wd = final_wd
77 |         self.T_max = T_max
78 |         self._step = 0.0
79 | 
80 |     def step(self):
81 |         self._step += 1
82 |         progress = self._step / self.T_max
83 |         new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress))
84 | 
85 |         if self.final_wd <= self.ref_wd:
86 |             new_wd = max(self.final_wd, new_wd)
87 |         else:
88 |             new_wd = min(self.final_wd, new_wd)
89 | 
90 |         for group in self.optimizer.param_groups:
91 |             if ("WD_exclude" not in group) or not group["WD_exclude"]:
92 |                 group["weight_decay"] = new_wd
93 |         return new_wd
94 | 
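A minimal usage sketch: these schedulers are stepped once per optimizer update (not per epoch), and WarmupCosineSchedule.step returns the learning rate it just set; the tiny model and step counts are illustrative only:

    import torch

    from src.utils.schedulers import WarmupCosineSchedule

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)
    scheduler = WarmupCosineSchedule(
        optimizer, warmup_steps=100, start_lr=1e-6, ref_lr=1e-3, T_max=1000, final_lr=1e-6
    )
    for step in range(1000):
        # ... forward / backward / optimizer.step() would go here ...
        lr = scheduler.step()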


--------------------------------------------------------------------------------
/src/utils/tensors.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import math
 7 | from logging import getLogger
 8 | 
 9 | import torch
10 | 
11 | logger = getLogger()
12 | 
13 | 
14 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
15 |     # Cut & paste from PyTorch official master until it's in a few official releases - RW
16 |     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
17 |     def norm_cdf(x):
18 |         # Computes standard normal cumulative distribution function
19 |         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
20 | 
21 |     with torch.no_grad():
22 |         # Values are generated by using a truncated uniform distribution and
23 |         # then using the inverse CDF for the normal distribution.
24 |         # Get upper and lower cdf values
25 |         lower = norm_cdf((a - mean) / std)
26 |         upper = norm_cdf((b - mean) / std)
27 | 
28 |         # Uniformly fill tensor with values from [lower, upper], then translate to
29 |         # [2*lower-1, 2*upper-1].
30 |         tensor.uniform_(2 * lower - 1, 2 * upper - 1)
31 | 
32 |         # Use inverse cdf transform for normal distribution to get truncated
33 |         # standard normal
34 |         tensor.erfinv_()
35 | 
36 |         # Transform to proper mean, std
37 |         tensor.mul_(std * math.sqrt(2.0))
38 |         tensor.add_(mean)
39 | 
40 |         # Clamp to ensure it's in the proper range
41 |         tensor.clamp_(min=a, max=b)
42 |         return tensor
43 | 
44 | 
45 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
46 |     # type: (Tensor, float, float, float, float) -> Tensor
47 |     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
48 | 
49 | 
50 | def repeat_interleave_batch(x, B, repeat):
51 |     N = len(x) // B
52 |     x = torch.cat([torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) for i in range(N)], dim=0)
53 |     return x
54 | 


--------------------------------------------------------------------------------
/src/utils/wrappers.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import torch.nn as nn
 7 | 
 8 | 
 9 | class MultiSeqWrapper(nn.Module):
10 | 
11 |     def __init__(self, backbone):
12 |         super().__init__()
13 |         self.backbone = backbone
14 | 
15 |     def forward(self, x, masks=None):
16 |         """
17 |         :param x: [list] List of Tensors of different seq lengths
18 |         :param masks: [list[list]] List of Tensors (outer index: masks for a given seq length, inner index: multi-masks for that seq length)
19 |         """
20 |         if masks is None:
21 |             return [self.backbone(xi) for xi in x]
22 | 
23 |         outs = [[] for _ in x]
24 |         for i, (xi, mi) in enumerate(zip(x, masks)):
25 |             for mij in mi:
26 |                 outs[i] += [self.backbone(xi, masks=mij)]
27 |         return outs
28 | 
29 | 
30 | class PredictorMultiSeqWrapper(nn.Module):
31 | 
32 |     def __init__(self, backbone):
33 |         super().__init__()
34 |         self.backbone = backbone
35 | 
36 |     def forward(self, x, masks_x, masks_y, has_cls=False):
37 |         n = 0
38 |         outs = [[] for _ in x]
39 |         for i, (xi, mxi, myi) in enumerate(zip(x, masks_x, masks_y)):
40 |             for xij, mxij, myij in zip(xi, mxi, myi):
41 |                 outs[i] += [self.backbone(xij, mxij, myij, mask_index=i, has_cls=has_cls)]
42 |                 n += 1
43 |         return outs
44 | 


--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import os
 7 | import sys
 8 | 
 9 | JEPA_ROOT = os.path.dirname(os.path.dirname(__file__))
10 | sys.path.append(JEPA_ROOT)
11 | 


--------------------------------------------------------------------------------
/tests/datasets/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import os
 7 | import sys
 8 | 
 9 | JEPA_ROOT = os.path.dirname(os.path.dirname(__file__))
10 | sys.path.append(JEPA_ROOT)
11 | 


--------------------------------------------------------------------------------
/tests/datasets/test_dataloader.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import unittest
 7 | 
 8 | from src.datasets.utils.dataloader import ConcatIndices
 9 | 
10 | 
11 | class TestConcatIndices(unittest.TestCase):
12 |     def test_concat_indices(self):
13 |         sizes = [10, 20, 30, 40]
14 |         total_size = sum(sizes)
15 |         concat_indices = ConcatIndices(sizes)
16 | 
17 |         # -1 is outside the total range
18 |         with self.assertRaises(ValueError):
19 |             concat_indices[-1]
20 |         # 0-9 map to dataset 0
21 |         self.assertEqual(concat_indices[0], (0, 0))
22 |         self.assertEqual(concat_indices[9], (0, 9))
23 |         # 10-29 map to dataset 1
24 |         self.assertEqual(concat_indices[10], (1, 0))
25 |         self.assertEqual(concat_indices[29], (1, 19))
26 |         # 30-59 map to dataset 2
27 |         self.assertEqual(concat_indices[30], (2, 0))
28 |         self.assertEqual(concat_indices[59], (2, 29))
29 |         # 60-99 map to dataset 3
30 |         self.assertEqual(concat_indices[60], (3, 0))
31 |         self.assertEqual(concat_indices[99], (3, 39))
32 |         # 100 is outside the total range
33 |         with self.assertRaises(ValueError):
34 |             concat_indices[total_size]
35 | 


--------------------------------------------------------------------------------
/tests/datasets/test_memory_efficient_sampler.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
  2 | #
  3 | # This source code is licensed under the MIT license found in the
  4 | # LICENSE file in the root directory of this source tree.
  5 | 
  6 | import unittest
  7 | 
  8 | from src.datasets.utils.weighted_sampler import (
  9 |     MemoryEfficientDistributedWeightedSampler,
 10 |     MemoryEfficientDistributedWeightedSamplerLessRepeat,
 11 | )
 12 | 
 13 | 
 14 | class MockDataset:
 15 |     def __init__(self, datasets, dataset_weights):
 16 |         self.datasets = datasets
 17 |         self.dataset_weights = dataset_weights
 18 | 
 19 |     def __len__(self):
 20 |         return sum(len(d) for d in self.datasets)
 21 | 
 22 | 
 23 | class TestMemoryEfficientSampler(unittest.TestCase):
 24 | 
 25 |     def test_shuffled_sampling_single(self):
 26 |         "The specific values returned are a function of the random sampler with the given seed."
 27 |         datasets = []
 28 |         for i in range(3):
 29 |             datasets.append([f"DS{i}"] * 100 * (i + 1))
 30 | 
 31 |         mock_dataset = MockDataset(datasets, [1, 1, 1])
 32 |         sampler = MemoryEfficientDistributedWeightedSampler(mock_dataset, num_replicas=1, rank=0, shuffle=True)
 33 | 
 34 |         smplr_it = iter(sampler)
 35 | 
 36 |         ex = next(smplr_it)
 37 |         self.assertIsNotNone(ex)
 38 |         self.assertEqual(ex, 202)  # Based on previous run
 39 | 
 40 |         ex = next(smplr_it)
 41 |         self.assertIsNotNone(ex)
 42 |         self.assertEqual(ex, 26)  # Based on previous run
 43 | 
 44 |     def test_shuffled_sampling(self):
 45 |         datasets = []
 46 |         for i in range(3):
 47 |             datasets.append([f"DS{i}"] * 100 * (i + 1))
 48 | 
 49 |         mock_dataset = MockDataset(datasets, [1, 2000, 1])
 50 |         sampler = MemoryEfficientDistributedWeightedSampler(mock_dataset, num_replicas=8, rank=3, shuffle=True)
 51 | 
 52 |         smplr_it = iter(sampler)
 53 | 
 54 |         # Notice how the following samples are drawn from the 2nd dataset, which has a weight of 2000.
 55 |         ex = next(smplr_it)
 56 |         self.assertIsNotNone(ex)
 57 |         self.assertEqual(ex, 135)  # Based on previous run
 58 | 
 59 |         ex = next(smplr_it)
 60 |         self.assertIsNotNone(ex)
 61 |         self.assertEqual(ex, 143)  # Based on previous run
 62 | 
 63 |     def test_non_shuffled_sampling(self):
 64 |         datasets = []
 65 |         for i in range(3):
 66 |             datasets.append([f"DS{i}"] * 100 * (i + 1))
 67 | 
 68 |         mock_dataset = MockDataset(datasets, [1, 10, 1])
 69 |         sampler = MemoryEfficientDistributedWeightedSampler(mock_dataset, num_replicas=4, rank=2, shuffle=False)
 70 | 
 71 |         smplr_it = iter(sampler)
 72 | 
 73 |         ex = next(smplr_it)
 74 |         self.assertIsNotNone(ex)
 75 |         self.assertEqual(ex, 102)  # Calculated based on the `__next__` function's non-shuffled logic.
 76 | 
 77 |         ex = next(smplr_it)
 78 |         self.assertIsNotNone(ex)
 79 |         self.assertEqual(ex, 106)  # Calculated based on the `__next__` function's non-shuffled logic.
 80 | 
 81 | 
 82 | class TestMemoryEfficientSamplerLessRepeat(unittest.TestCase):
 83 | 
 84 |     def test_shuffled_sampling_single(self):
 85 |         """
 86 |         Testing all weights are equal to 1.
 87 |         The specific values returned are a function of the random sampler with the given seed.
 88 |         """
 89 |         datasets = []
 90 |         for i in range(3):
 91 |             datasets.append([f"DS{i}"] * 100 * (i + 1))
 92 | 
 93 |         mock_dataset = MockDataset(datasets, [1, 1, 1])
 94 |         sampler = MemoryEfficientDistributedWeightedSamplerLessRepeat(
 95 |             mock_dataset, num_replicas=1, rank=0, shuffle=True
 96 |         )
 97 | 
 98 |         smplr_it = iter(sampler)
 99 | 
100 |         ex = next(smplr_it)
101 |         self.assertIsNotNone(ex)
102 |         self.assertEqual(ex, 144)  # Based on previous run
103 | 
104 |         ex = next(smplr_it)
105 |         self.assertIsNotNone(ex)
106 |         self.assertEqual(ex, 84)  # Based on previous run
107 | 
108 |     def test_shuffled_sampling(self):
109 |         """
110 |         Testing one dominant dataset.
111 |         The specific values returned are a function of the random sampler with the given seed.
112 |         """
113 |         datasets = []
114 |         for i in range(3):
115 |             datasets.append([f"DS{i}"] * 100 * (i + 1))
116 | 
117 |         mock_dataset = MockDataset(datasets, [1, 2000, 1])
118 |         sampler = MemoryEfficientDistributedWeightedSamplerLessRepeat(
119 |             mock_dataset, num_replicas=8, rank=3, shuffle=True
120 |         )
121 | 
122 |         smplr_it = iter(sampler)
123 | 
124 |         # Notice how the following samples are drawn from the 2nd dataset, which has a weight of 2000.
125 |         ex = next(smplr_it)
126 |         self.assertIsNotNone(ex)
127 |         self.assertEqual(ex, 255)  # Based on previous run
128 | 
129 |         ex = next(smplr_it)
130 |         self.assertIsNotNone(ex)
131 |         self.assertEqual(ex, 231)  # Based on previous run
132 | 
133 |     def test_non_shuffled_sampling(self):
134 |         datasets = []
135 |         for i in range(3):
136 |             datasets.append([f"DS{i}"] * 100 * (i + 1))
137 | 
138 |         mock_dataset = MockDataset(datasets, [1, 10, 1])
139 |         sampler = MemoryEfficientDistributedWeightedSamplerLessRepeat(
140 |             mock_dataset, num_replicas=4, rank=2, shuffle=False
141 |         )
142 | 
143 |         smplr_it = iter(sampler)
144 | 
145 |         ex = next(smplr_it)
146 |         self.assertIsNotNone(ex)
147 |         self.assertEqual(ex, 102)  # Calculated based on the `__next__` function's non-shuffled logic.
148 | 
149 |         ex = next(smplr_it)
150 |         self.assertIsNotNone(ex)
151 |         self.assertEqual(ex, 106)  # Calculated based on the `__next__` function's non-shuffled logic.
152 | 
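Beyond these unit tests, the samplers are meant to be passed to a DataLoader as its sampler; the wrapped dataset must expose the same `datasets` and `dataset_weights` attributes that MockDataset mimics above. A rough sketch under those assumptions (ToyConcatDataset and its contents are hypothetical):

    from torch.utils.data import DataLoader, Dataset

    from src.datasets.utils.weighted_sampler import MemoryEfficientDistributedWeightedSampler


    class ToyConcatDataset(Dataset):
        # Hypothetical concatenated dataset exposing the attributes the sampler reads.
        def __init__(self):
            self.datasets = [list(range(100)), list(range(200))]
            self.dataset_weights = [1, 3]
            self._flat = [x for d in self.datasets for x in d]

        def __len__(self):
            return len(self._flat)

        def __getitem__(self, idx):
            return self._flat[idx]


    dataset = ToyConcatDataset()
    sampler = MemoryEfficientDistributedWeightedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)
    batch = next(iter(loader))  # samples come roughly 3x as often from the second dataset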


--------------------------------------------------------------------------------
/tests/datasets/test_vjepa_transforms.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import unittest
 7 | 
 8 | import numpy as np
 9 | import torch
10 | 
11 | from app.vjepa import transforms
12 | from src.datasets.utils.video import functional
13 | from src.datasets.utils.video.volume_transforms import ClipToTensor
14 | 
15 | 
16 | class TestNormalize(unittest.TestCase):
17 | 
18 |     def setUp(self):
19 |         self.g = torch.Generator()
20 |         self.g.manual_seed(42)
21 | 
22 |     def test_approximation_equivalence(self):
23 |         T, H, W, C = 16, 224, 224, 3
24 |         shape = (T, H, W, C)
25 |         mean = torch.tensor([0.485, 0.456, 0.406])
26 |         std = torch.tensor([0.229, 0.224, 0.225])
27 |         for i in range(10):
28 |             X = torch.randint(low=0, high=255, size=shape, generator=self.g, dtype=torch.uint8)
29 |             X_clone = X.clone().permute(3, 0, 1, 2)  # C, T, H, W
30 | 
31 |             X_norm = transforms.tensor_normalize(X, mean, std)
32 |             X_norm_fast = transforms._tensor_normalize_inplace(X_clone, 255.0 * mean, 255.0 * std)
33 |             self.assertTrue(torch.allclose(X_norm, X_norm_fast.permute(1, 2, 3, 0)))
34 | 
35 | 
36 | class TestVideoTransformFunctionalCrop(unittest.TestCase):
37 |     def test_tensor_numpy(self):
38 |         T, C, H, W = 16, 3, 280, 320
39 |         shape = (T, C, H, W)
40 |         crop_size = (10, 10, 224, 224)
41 |         video_tensor = torch.randint(low=0, high=255, size=shape, dtype=torch.uint8)
42 |         video_numpy = video_tensor.numpy()
43 | 
44 |         cropped_tensor = functional.crop_clip(video_tensor, *crop_size)
45 |         self.assertIsInstance(cropped_tensor[0], torch.Tensor)
46 | 
47 |         cropped_np_array = functional.crop_clip(video_numpy, *crop_size)
48 |         self.assertIsInstance(cropped_np_array[0], np.ndarray)
49 | 
50 |         for clip_tensor, clip_np in zip(cropped_tensor, cropped_np_array):
51 |             torch.testing.assert_close(clip_tensor, torch.Tensor(clip_np).to(dtype=torch.uint8))
52 | 
53 | 
54 | class TestVideoTransformFunctionalResize(unittest.TestCase):
55 |     def test_tensor_numpy(self):
56 |         T, C, H, W = 16, 3, 280, 320
57 |         shape = (T, C, H, W)
58 |         resize_to = 256
59 | 
60 |         video_tensor = torch.randint(low=0, high=255, size=shape, dtype=torch.int16)
61 |         # We permute the videos because our underlying numpy.array-based transforms expect
62 |         # images in (H, W, C) shape, whereas our tensor transforms mostly use (C, H, W)
63 |         video_numpy = video_tensor.permute(0, 2, 3, 1).numpy()  # (T, C, H, W) -> (T, H, W, C)
64 | 
65 |         resized_tensor = functional.resize_clip(video_tensor, resize_to)
66 |         self.assertIsInstance(resized_tensor[0], torch.Tensor)
67 | 
68 |         resized_np_array = functional.resize_clip(video_numpy, resize_to)
69 |         self.assertIsInstance(resized_np_array[0], np.ndarray)
70 | 
71 |         for clip_tensor, clip_np in zip(resized_tensor, resized_np_array):
72 |             clip_tensor = clip_tensor.permute(1, 2, 0)
73 |             diff = torch.mean((torch.abs(clip_tensor - torch.Tensor(clip_np).to(torch.int16))) / (clip_tensor + 1))
74 | 
75 |             # The transformations cannot match exactly because their interpolation functions come from
76 |             # two different sources. Here we check their relative difference instead.
77 |             # See the discussion here: https://github.com/fairinternal/jepa-internal/pull/65#issuecomment-2101833959
78 |             self.assertLess(diff, 0.05)
79 | 
80 | 
81 | class TestVideoTransformClipToTensor(unittest.TestCase):
82 |     def test_tensor_numpy(self):
83 |         T, C, H, W = 16, 3, 280, 320
84 |         shape = (T, C, H, W)
85 |         transform = ClipToTensor()
86 | 
87 |         video_tensor = [clip for clip in torch.randint(low=0, high=255, size=shape, dtype=torch.int16)]
88 |         # We permute the videos because our underlying numpy.array-based transforms expect
89 |         # images in (H, W, C) shape, whereas our tensor transforms mostly use (C, H, W)
90 |         video_numpy = [clip.permute(1, 2, 0).numpy() for clip in video_tensor]
91 |         torch.testing.assert_close(transform(video_tensor), transform(video_numpy))
92 | 
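The factor of 255 applied to mean and std in TestNormalize is consistent with the identity the fast path appears to rely on: for a uint8 pixel value x, (x / 255 - m) / s = (x - 255*m) / (255*s), so normalizing the raw [0, 255] tensor with (255*m, 255*s) should match first rescaling to [0, 1] and then normalizing with (m, s).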


--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import os
 7 | import sys
 8 | 
 9 | JEPA_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))  # repo root, not tests/
10 | sys.path.append(JEPA_ROOT)
11 | 


--------------------------------------------------------------------------------
/tests/models/test_models.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import unittest
 7 | 
 8 | import torch
 9 | 
10 | from src.models.vision_transformer import VIT_EMBED_DIMS, vit_tiny
11 | 
12 | 
13 | class TestImageViT(unittest.TestCase):
14 |     def setUp(self) -> None:
15 |         self._vit_tiny = vit_tiny()
16 |         self.height, self.width = 224, 224
17 |         self.num_patches = (self.height // self._vit_tiny.patch_size) * (self.width // self._vit_tiny.patch_size)
18 | 
19 |     def test_model_image_nomask_batchsize_4(self):
20 |         BS = 4
21 |         x = torch.rand((BS, 3, self.height, self.width))
22 |         y = self._vit_tiny(x)
23 |         self.assertIsInstance(y, torch.Tensor)
24 |         self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"]))
25 | 
26 |     def test_model_image_nomask_batchsize_1(self):
27 |         BS = 1
28 |         x = torch.rand((BS, 3, self.height, self.width))
29 |         y = self._vit_tiny(x)
30 |         self.assertIsInstance(y, torch.Tensor)
31 |         self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"]))
32 | 
33 |     def test_model_image_masked_batchsize_4(self):
34 |         BS = 4
35 |         mask_indices = [6, 7, 8]
36 |         masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)]
37 |         x = torch.rand((BS, 3, self.height, self.width))
38 |         y = self._vit_tiny(x, masks=masks)
39 |         self.assertIsInstance(y, torch.Tensor)
40 |         self.assertEqual(y.size(), (BS, len(mask_indices), VIT_EMBED_DIMS["vit_tiny"]))
41 | 
42 |     def test_model_image_masked_batchsize_1(self):
43 |         BS = 1
44 |         mask_indices = [6, 7, 8]
45 |         masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)]
46 |         x = torch.rand((BS, 3, self.height, self.width))
47 |         y = self._vit_tiny(x, masks=masks)
48 |         self.assertIsInstance(y, torch.Tensor)
49 |         self.assertEqual(y.size(), (BS, len(mask_indices), VIT_EMBED_DIMS["vit_tiny"]))
50 | 
51 | 
52 | class TestVideoViT(unittest.TestCase):
53 |     def setUp(self) -> None:
54 |         self.num_frames = 8
55 |         self._vit_tiny = vit_tiny(num_frames=8)
56 |         self.height, self.width = 224, 224
57 |         self.num_patches = (
58 |             (self.height // self._vit_tiny.patch_size)
59 |             * (self.width // self._vit_tiny.patch_size)
60 |             * (self.num_frames // self._vit_tiny.tubelet_size)
61 |         )
62 | 
63 |     def test_model_video_nomask_batchsize_4(self):
64 |         BS = 4
65 |         x = torch.rand((BS, 3, self.num_frames, self.height, self.width))
66 |         y = self._vit_tiny(x)
67 |         self.assertIsInstance(y, torch.Tensor)
68 |         self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"]))
69 | 
70 |     def test_model_video_nomask_batchsize_1(self):
71 |         BS = 1
72 |         x = torch.rand((BS, 3, self.num_frames, self.height, self.width))
73 |         y = self._vit_tiny(x)
74 |         self.assertIsInstance(y, torch.Tensor)
75 |         self.assertEqual(y.size(), (BS, self.num_patches, VIT_EMBED_DIMS["vit_tiny"]))
76 | 
77 |     def test_model_video_masked_batchsize_4(self):
78 |         BS = 4
79 |         mask_indices = [6, 7, 8]
80 |         masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)]
81 |         x = torch.rand((BS, 3, self.num_frames, self.height, self.width))
82 |         y = self._vit_tiny(x, masks=masks)
83 |         self.assertIsInstance(y, torch.Tensor)
84 |         self.assertEqual(y.size(), (BS, len(mask_indices), VIT_EMBED_DIMS["vit_tiny"]))
85 | 
86 |     def test_model_video_masked_batchsize_1(self):
87 |         BS = 1
88 |         mask_indices = [6, 7, 8]
89 |         masks = [torch.tensor(mask_indices, dtype=torch.int64) for _ in range(BS)]
90 |         x = torch.rand((BS, 3, self.num_frames, self.height, self.width))
91 |         y = self._vit_tiny(x, masks=masks)
92 |         self.assertIsInstance(y, torch.Tensor)
93 |         self.assertEqual(y.size(), (BS, len(mask_indices), VIT_EMBED_DIMS["vit_tiny"]))
94 | 


--------------------------------------------------------------------------------
/tests/models/test_predictor.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import unittest
 7 | 
 8 | import torch
 9 | 
10 | from src.models.predictor import VisionTransformerPredictor
11 | 
12 | 
13 | class TestImagePredictorMaskTokens(unittest.TestCase):
14 |     def setUp(self) -> None:
15 |         self._embed_dim = 768
16 |         self._predictor = VisionTransformerPredictor(embed_dim=self._embed_dim, use_mask_tokens=True)
17 | 
18 |     def test_image_predictor_batchsize_4(self):
19 |         BS = 4
20 |         enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)]
21 |         target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)]
22 |         enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim))
23 |         y = self._predictor(enc, enc_mask_indices, target_mask_indices)
24 |         self.assertIsInstance(y, torch.Tensor)
25 |         self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim))
26 | 
27 |     def test_image_predictor_batchsize_1(self):
28 |         BS = 1
29 |         enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)]
30 |         target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)]
31 |         enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim))
32 |         y = self._predictor(enc, enc_mask_indices, target_mask_indices)
33 |         self.assertIsInstance(y, torch.Tensor)
34 |         self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim))
35 | 
36 | 
37 | class TestVideoPredictorMaskTokens(unittest.TestCase):
38 |     def setUp(self) -> None:
39 |         self._embed_dim = 768
40 |         self._predictor = VisionTransformerPredictor(embed_dim=self._embed_dim, use_mask_tokens=True)
41 | 
42 |     def test_video_predictor_batchsize_4(self):
43 |         BS = 4
44 |         enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)]
45 |         target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)]
46 |         enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim))
47 |         y = self._predictor(enc, enc_mask_indices, target_mask_indices)
48 |         self.assertIsInstance(y, torch.Tensor)
49 |         self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim))
50 | 
51 |     def test_video_predictor_batchsize_1(self):
52 |         BS = 1
53 |         enc_mask_indices = [torch.tensor(BS * [[6, 7, 8]], dtype=torch.int64)]
54 |         target_mask_indices = [torch.tensor(BS * [[16, 17, 18, 19]], dtype=torch.int64)]
55 |         enc = torch.rand((BS, len(enc_mask_indices[0][0]), self._embed_dim))
56 |         y = self._predictor(enc, enc_mask_indices, target_mask_indices)
57 |         self.assertIsInstance(y, torch.Tensor)
58 |         self.assertEqual(y.size(), (BS, target_mask_indices[0].size(1), self._embed_dim))
59 | 


--------------------------------------------------------------------------------
/tests/models/test_vision_transformer.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | #
 3 | # This source code is licensed under the MIT license found in the
 4 | # LICENSE file in the root directory of this source tree.
 5 | 
 6 | import unittest
 7 | from copy import deepcopy
 8 | 
 9 | import numpy as np
10 | import pytest
11 | import torch
12 | 
13 | from src.models.vision_transformer import vit_giant_xformers_rope
14 | 
15 | 
16 | # Usage: pytest tests/models/test_vision_transformer.py
17 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires CUDA")
18 | class TestViTGiant(unittest.TestCase):
19 |     def setUp(self) -> None:
20 |         self.model_shape_invariant = vit_giant_xformers_rope(
21 |             img_size=256, patch_size=16, num_frames=16, handle_nonsquare_inputs=True
22 |         ).cuda()
23 |         self.model_square = deepcopy(self.model_shape_invariant)
24 |         self.model_square.handle_nonsquare_inputs = False
25 |         torch.manual_seed(42)
26 |         self.total_iters = 10
27 | 
28 |     def test_square_inputs(self):
29 |         for i in range(self.total_iters):
30 |             input = torch.rand(1, 3, 16, 256, 256).cuda()
31 |             with torch.cuda.amp.autocast(enabled=True):
32 |                 with torch.no_grad():
33 |                     out1 = self.model_shape_invariant(input)
34 |                     out2 = self.model_square(input)
35 |                     torch.testing.assert_close(out1, out2)
36 | 
37 |     def test_square_inputs_with_mask(self):
38 |         for i in range(self.total_iters):
39 |             input = torch.rand(1, 3, 16, 256, 256).cuda()
40 |             mask = torch.randint(0, 2, (1, 2048)).cuda()
41 |             with torch.cuda.amp.autocast(enabled=True):
42 |                 with torch.no_grad():
43 |                     out1 = self.model_shape_invariant(input, masks=mask)
44 |                     out2 = self.model_square(input, masks=mask)
45 |                     torch.testing.assert_close(out1, out2)
46 | 
47 |     def test_nonsquare_inputs(self):
48 |         for i in range(self.total_iters):
49 |             rand_width = np.random.randint(256, 512)
50 |             rand_height = np.random.randint(256, 512)
51 |             input = torch.rand(1, 3, 16, rand_height, rand_width).cuda()
52 |             # Since input is interpolated, output won't be exactly the same
53 |             input_resized_to_square = [
54 |                 torch.nn.functional.interpolate(input[:, :, frame_idx], size=256, mode="bicubic")
55 |                 for frame_idx in range(input.shape[2])
56 |             ]
57 |             input_resized_to_square = torch.stack(input_resized_to_square, dim=2)
58 | 
59 |             with torch.cuda.amp.autocast(enabled=True):
60 |                 with torch.no_grad():
61 |                     out1 = self.model_shape_invariant(input).mean(dim=1)
62 |                     out2 = self.model_square(input_resized_to_square).mean(dim=1)
63 |                     self.assertAlmostEqual(torch.nn.functional.cosine_similarity(out2, out1).item(), 1.0, places=3)
64 | 


--------------------------------------------------------------------------------