├── .github └── workflows │ └── basic_checks.yml ├── .gitignore ├── LICENSE ├── demos ├── 1_Load_SAE.ipynb ├── 2_Train_SAE.ipynb ├── 3_Evaluate_SAE.ipynb ├── Bonus_SAE_notebooks │ └── Evaluating_Gytis_SAE.ipynb ├── Emoji_Logit_Lens_Demo.ipynb ├── Interactive_Attention_Head_Tour.ipynb └── ViT_Prisma_Main_Demo.ipynb ├── docs ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── PRISMA_CONTRIBUTORS.md ├── README.md ├── UsageGuide.md ├── assets │ └── images │ │ ├── attention head 1.png │ │ ├── attention head 2.png │ │ ├── attention head 3.png │ │ ├── cat toilet segmentation.png │ │ ├── cheetah segmentation.png │ │ ├── child lion segmentation.png │ │ ├── corner-head.gif │ │ ├── direct logit attribution.png │ │ ├── dogit lens 2.png │ │ ├── dogit lens.png │ │ ├── house.jpg │ │ ├── house2.jpg │ │ ├── monogenic_induction.png │ │ ├── patched head.png │ │ ├── polygenic_induction.png │ │ ├── prism1.jpg │ │ ├── prism2.jpg │ │ └── prisma.jpg ├── dataset_descriptions │ └── Induction.md ├── experiments │ ├── ImagenetResults.md │ └── dSprites_results.md ├── list_of_supported_models.md ├── sae_experiment_tracking.md └── sae_table.md ├── pretrained_classifiers └── clip_benchmark │ └── imagenet_classifier_hf_hub_laion_CLIP_ViT_B_32_DataComp.XL_s13B_b90K.npy ├── setup.py ├── src └── vit_prisma │ ├── __init__.py │ ├── configs │ ├── HookedTextTransformerConfig.py │ ├── HookedViTConfig.py │ └── __init__.py │ ├── dataloaders │ ├── __init__.py │ ├── cifar10.py │ ├── circle.py │ ├── conceptual_captions.py │ ├── dsprites.py │ ├── imagenet_classes_simple.py │ ├── imagenet_dataset.py │ ├── imagenet_index.py │ ├── induction.py │ └── polygenic_induction.py │ ├── model_eval │ ├── __init__.py │ └── evaluate_imagenet.py │ ├── models │ ├── __init__.py │ ├── activation_fns.py │ ├── base_text_transformer.py │ ├── base_transformer.py │ ├── base_vit.py │ ├── build_zero_shot_classifier.py │ ├── layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── head.py │ │ ├── layer_norm.py │ │ ├── mlp.py │ │ ├── patch_embedding.py │ │ ├── position_embedding.py │ │ └── transformer_block.py │ ├── model_config_registry.py │ ├── model_loader.py │ └── weight_conversion.py │ ├── prisma_tools │ ├── __init__.py │ ├── activation_cache.py │ ├── factored_matrix.py │ ├── hook_point.py │ ├── hooked_root_module.py │ ├── lens_handle.py │ └── logit_lens.py │ ├── sae │ ├── __init__.py │ ├── config.py │ ├── evals │ │ ├── Evaluating_Vision_SAE.ipynb │ │ ├── __init__.py │ │ ├── eval_feature_list.py │ │ ├── eval_neuron_basis.py │ │ ├── evals.py │ │ ├── model_eval.py │ │ ├── mustache_decoder.ipynb │ │ └── train_tinyclip_kadinsky_adapter.py │ ├── sae.py │ ├── sae_utils.py │ ├── train_sae.py │ ├── training │ │ ├── __init__.py │ │ ├── activations_store.py │ │ ├── geometric_median.py │ │ └── get_scheduler.py │ └── transcoder.py │ ├── sample_images │ ├── cat_crop.jpeg │ ├── cat_dog.jpeg │ ├── cheetah.jpeg │ ├── child_lion.jpeg │ ├── crop_cat.png │ ├── crop_dog.jpeg │ ├── fruit_bowl.jpeg │ ├── n01818515_39.JPEG │ ├── n01818515_39_typographic.JPEG │ ├── sea_slug_typogrpahic.JPEG │ └── toilet_cat.jpeg │ ├── training │ ├── __init__.py │ ├── early_stopping.py │ ├── schedulers.py │ ├── trainer.py │ ├── training_dictionary.py │ └── training_utils.py │ ├── transforms │ ├── __init__.py │ └── model_transforms.py │ ├── utils │ ├── __init__.py │ ├── constants.py │ ├── data_utils │ │ ├── __init__.py │ │ ├── cifar │ │ │ ├── __init__.py │ │ │ └── cifar_10_utils.py │ │ ├── imagenet │ │ │ ├── __init__.py │ │ │ ├── imagenet100_classes.py │ │ │ ├── imagenet_dict.py │ │ │ ├── 
imagenet_emoji.py │ │ │ └── imagenet_utils.py │ │ └── loader.py │ ├── devices.py │ ├── enums.py │ ├── get_activations.py │ ├── load_model.py │ ├── openai_templates.py │ ├── prisma_utils.py │ ├── saving_utils.py │ ├── tutorial_utils.py │ └── wandb_utils.py │ ├── visualization │ ├── __init__.py │ ├── get_color.js │ ├── main_visualize.js │ ├── patch_level_logit_lens.py │ ├── patch_to_img.js │ ├── sample_cifar10_image_0.npy │ ├── sample_cifar10_image_10.npy │ ├── template.html │ ├── visualize_attention.py │ ├── visualize_attention_js.py │ └── visualize_image.py │ └── vjepa_hf │ ├── README.md │ ├── __init__.py │ ├── configs │ ├── __init__.py │ ├── configs.py │ └── sae_jepa_config.py │ ├── convert_vjepa_to_hf.py │ ├── evals │ ├── __init__.py │ ├── eval_configs │ │ ├── __init__.py │ │ └── imagenet_config.yaml │ ├── imagenet_classification.py │ ├── main_distributed.py │ └── scaffold.py │ ├── modeling_vjepa.py │ ├── paths_cw.yaml │ ├── src │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── airstore_dataset.py │ │ ├── audio_video_dataset.py │ │ ├── blobstore_dataset.py │ │ ├── data_manager.py │ │ ├── dm_perception.py │ │ ├── ht100m_captions.py │ │ ├── imagenet1k.py │ │ ├── imagenet22k.py │ │ ├── imagenet_subsets │ │ │ └── 1percent.txt │ │ ├── iterable_video_dataset.py │ │ ├── laion.py │ │ ├── lvd_images.py │ │ ├── mixed_dataset.py │ │ ├── mvbench.py │ │ ├── utils │ │ │ ├── dataloader.py │ │ │ ├── kenburns.py │ │ │ ├── tokenizers │ │ │ │ ├── __init__.py │ │ │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ │ │ └── tokenizer.py │ │ │ ├── utils.py │ │ │ ├── video │ │ │ │ ├── functional.py │ │ │ │ ├── randaugment.py │ │ │ │ ├── randerase.py │ │ │ │ ├── transforms.py │ │ │ │ ├── transforms_builder.py │ │ │ │ └── volume_transforms.py │ │ │ └── weighted_sampler.py │ │ ├── video_dataset.py │ │ ├── video_webdataset.py │ │ └── webvid.py │ ├── masks │ │ ├── __init__.py │ │ ├── causal.py │ │ ├── default.py │ │ ├── multiblock.py │ │ ├── multiblock3d.py │ │ ├── multiseq_multiblock3d.py │ │ ├── random.py │ │ ├── random_tube.py │ │ ├── simple_multiblock.py │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── attentive_pooler.py │ │ ├── baselines │ │ │ ├── dinov2.py │ │ │ ├── dinov2_layers │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── block.py │ │ │ │ ├── dino_head.py │ │ │ │ ├── drop_path.py │ │ │ │ ├── layer_scale.py │ │ │ │ ├── mlp.py │ │ │ │ ├── patch_embed.py │ │ │ │ └── swiglu_ffn.py │ │ │ ├── omnimae.py │ │ │ ├── openclip.py │ │ │ ├── videomae.py │ │ │ └── videomaev2.py │ │ ├── predictor.py │ │ ├── text_transformer.py │ │ ├── utils │ │ │ ├── modules.py │ │ │ ├── multimask.py │ │ │ ├── patch_embed.py │ │ │ ├── pos_embs.py │ │ │ └── rope.py │ │ └── vision_transformer.py │ └── utils │ │ ├── __init__.py │ │ ├── adamw.py │ │ ├── cluster.py │ │ ├── cluster_processor │ │ ├── __init__.py │ │ ├── _cluster_processor.py │ │ ├── _local_file_dataset.py │ │ └── _rsc_cluster_processor.py │ │ ├── distributed.py │ │ ├── logging.py │ │ ├── model.py │ │ ├── monitoring.py │ │ ├── profiler.py │ │ ├── schedulers.py │ │ ├── temp_ssd_dir.py │ │ └── tensors.py │ ├── test_vjepa.py │ └── train_jepa_sae.py └── tests ├── __init__.py ├── models ├── neater_test_loading_clip.py ├── test_loading_all_models.py ├── test_loading_dino.py ├── test_loading_kandinsky_vit_l.py ├── test_loading_timm.py ├── test_loading_vit_for_image_classification.py ├── test_loading_vivet.py ├── test_models.py └── test_openclip.py ├── sae ├── test_activation_cache_inference_speed.py ├── test_load_VisionModelSAERunnerConfig.py └── 
test_sae_training.py ├── test_cache_hook_names.py ├── test_hooks.py ├── test_loading_CLIP-ViT-B-32-DataComp-XL-s13B-b90K.py ├── test_notebooks ├── Untitled.ipynb └── timm_comparison.ipynb ├── test_text_encoder ├── load_text_model.py └── test_diagram_img.png └── test_weight_properties.py /.github/workflows/basic_checks.yml: -------------------------------------------------------------------------------- 1 | name: Basic checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | pull_request: 9 | branches: 10 | - main 11 | - dev 12 | 13 | jobs: 14 | build-and-test: 15 | name: Install package and run tests 16 | runs-on: ubuntu-latest 17 | 18 | strategy: 19 | matrix: 20 | python-version: 21 | - "3.10" 22 | steps: 23 | - name: Check out the repo content 24 | uses: actions/checkout@v3 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: Install the package 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install .[test] 35 | 36 | - name: Run tests 37 | run: | 38 | pytest tests 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | 3 | # Other repos 4 | beta-tcvae 5 | disent 6 | 7 | # Added 8 | data/* 9 | **slurm** 10 | *slurm* 11 | my_draft/* 12 | checkpoints/* 13 | checkpoints 14 | my_draft 15 | scripts 16 | *.vscode 17 | 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Output files 28 | *sbatch_err 29 | *sbatch_out 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # Environment and configuration files 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | prisma_env/ 95 | .prisma_env 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | 107 | 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | # wandb logs anywhere in folder 118 | **/wandb/ 119 | 120 | # Results directories 121 | results/ 122 | 123 | # Config files 124 | configs/ 125 | 126 | 127 | 128 | 129 | # Directories to keep, but ignore contents 130 | !data/ 131 | !data/.gitkeep 132 | !results/ 133 | !results/.gitkeep 134 | !configs/ 135 | !configs/.gitkeep 136 | 137 | # User-specific and OS-specific files 138 | .DS_Store 139 | .DS_Store? 
140 | ._* 141 | .Spotlight-V100 142 | .Trashes 143 | ehthumbs.db 144 | Thumbs.db 145 | .idea 146 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sonia Joseph 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ViT Prisma 2 | 3 | First and foremost, thank you for considering contributing to ViT-Planetarium. It's people like you that make ViT-Planetarium such a great tool. 4 | 5 | ## Getting Started 6 | 7 | - Fork the repository on GitHub. 8 | - Clone your forked repository to your local machine. 9 | - Create a new branch for your feature or fix. 10 | - Make your changes and commit them with a meaningful commit message. 11 | - Push your changes to your fork on GitHub. 12 | - Open a pull request against the main branch of the official repo. 13 | 14 | ## Coding Standards 15 | 16 | - Follow the code style and conventions used throughout the existing codebase. 17 | - Include comments in your code as necessary. 18 | - Document each function with input and output arguments (include types) and the purpose of the function or class. 19 | - Write tests for any new code, and ensure all tests pass before submitting a pull request. 20 | 21 | ## Reporting Bugs 22 | 23 | - Ensure the bug was not already reported by searching on GitHub under Issues. 24 | - If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring. 25 | 26 | ## Suggesting Enhancements 27 | - Feel free to discuss any new ideas on the Discussions tab 28 | - Open a new issue on GitHub with a detailed description of the feature. 29 | - Include the steps that it would enhance, and the possible benefits. 30 | 31 | ## Pull Requests 32 | 33 | - One pull request for one feature/fix. Do not mix multiple subjects into a single PR. 34 | - Update the `README.md` or documentation as necessary with your changes. 35 | - Delete your branch once changes have been merged. 
36 | 37 | ## Code of Conduct 38 | 39 | By participating in this project, you are expected to uphold our [Code of Conduct](CODE_OF_CONDUCT.md). 40 | 41 | -------------------------------------------------------------------------------- /docs/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sonia Joseph 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /docs/PRISMA_CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | # Prisma Contributors 2 | 3 | We are a community of researchers dedicated to making this repo beautiful, functional, and open source. 4 | 5 | If you'd like to join us, check out our [open issues](https://github.com/soniajoseph/ViT-Prisma/issues), our [contributing guidelines](https://github.com/soniajoseph/ViT-Prisma/blob/main/docs/CONTRIBUTING.md), and our [code of conduct](https://github.com/soniajoseph/ViT-Prisma/blob/main/docs/CODE_OF_CONDUCT.md). 6 | 7 | ## Contributors 8 | 9 | *In alphabetical order* 10 | 11 | * [Edward Stevinson](https://github.com/stevinson) 12 | * [Ethan Goldfarb](https://github.com/ekg15) 13 | * [Lorenz Hufe](https://github.com/lowlorenz) 14 | * [Praneet Suresh](https://github.com/PraneetNeuro) 15 | * [Rob Graham](https://github.com/themachinefan) 16 | * [Sonia Joseph](https://github.com/soniajoseph) 17 | * [Yash Vadi](https://github.com/YashVadi) 18 | -------------------------------------------------------------------------------- /docs/UsageGuide.md: -------------------------------------------------------------------------------- 1 | # Usage Guide 2 | 3 | The Prisma Project is a framework that facilitates training Vision Transformers (ViTs) in various configurations, with synthetically generated datasets suited for mechanistic interpretability. The library also allows loading pretrained models from huggingface and timm, which can further be finetuned using the trainer built into the framework. 4 | 5 | ## Installing Repo 6 | 7 | To install as an editable repo: 8 | 9 | ``` 10 | git clone https://github.com/soniajoseph/ViT-Prisma 11 | cd ViT-Prisma 12 | pip install -e . 13 | ``` 14 | 15 | ## Configurations 16 | 17 | The Prisma config object houses certain hyperparameters of the models, along with attributes of the training procedure, which can come in handy to track, and design new experiments. 
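As a quick illustration, here is a minimal sketch of building a custom config for a small experiment. The attribute names used below (`n_layers`, `d_model`, `d_head`, `n_heads`, `d_mlp`, `patch_size`, `image_size`, `n_classes`, `lr`, `batch_size`, `num_epochs`) come from the `HookedViTConfig` dataclass in this repo; the specific values are illustrative placeholders, not recommended settings.

```python
from vit_prisma.configs.HookedViTConfig import HookedViTConfig

# A small "mouse-sized" ViT config for a toy experiment.
# All values here are illustrative only.
config = HookedViTConfig(
    n_layers=2,
    d_model=128,
    d_head=32,
    n_heads=4,
    d_mlp=512,
    patch_size=8,
    image_size=64,
    n_classes=3,
    lr=1e-4,
    batch_size=512,
    num_epochs=10,
)
```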
18 | 19 | ```python 20 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 21 | ``` 22 | By inspecting the HookedViTConfig config available in the framework, you can create your own one based on the experiment. 23 | 24 | ## Hooked ViT 25 | 26 | The Hooked ViT part of the Prisma Project allows instantiating models in a flexible way along with hooks to inspect the internals of a model during forward / backward passes, making it easier for mechanistic interpretability research with ViTs. The architecture is defined by the Prisma config object discussed in the previous section. 27 | 28 | ### Instantiating the Base ViT 29 | 30 | ```python 31 | from vit_prisma.models.base_vit import HookedViT 32 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 33 | 34 | # The HookedViTConfig is available in the framework itself, which is the base setup for any experiment, one can customize it as per the requirements of an experiment. 35 | config = HookedViTConfig() 36 | 37 | model = HookedViT(config) 38 | ``` 39 | 40 | ### Using pretrained model 41 | 42 | ```python 43 | import vit_prisma 44 | model_name = "open-clip:timm/vit_base_patch32_clip_224.laion2b_e16" 45 | hooked_clip = vit_prisma.load_hooked_model(model_name) 46 | transform = vit_prisma.get_model_transforms(model_name) 47 | ``` 48 | 49 | ### Using pretrained models (Legacy) 50 | 51 | ```python 52 | from vit_prisma.models.pretrained_model import PretrainedModel 53 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 54 | 55 | config = HookedViTConfig() 56 | 57 | hf_model = PretrainedModel('google/vit-base-patch16-224-in21k', config) 58 | 59 | timm_model = PretrainedModel('vit_base_patch32_224', config, is_timm=True) 60 | ``` 61 | 62 | ### Datasets 63 | 64 | The framework also has a few datasets that can be synthetically generated, and cached. 65 | 66 | Below are the datasets available as part of the framework: 67 | - Circle 68 | - [dsprites](https://github.com/google-deepmind/dsprites-dataset) 69 | - Induction 70 | 71 | You will be able to use any other dataset with the framework, for ease of use, you can instantiate it as a Pytorch dataset object. 72 | 73 | ```python 74 | from vit_prisma.dataloaders.induction import InductionDataset 75 | 76 | train = InductionDataset('train') 77 | 78 | # If you don't provide a test split to the trainer, the trainer will perform the train/test split for you. 79 | ``` 80 | 81 | ### Training and tracking experiments 82 | 83 | The trainer has built in support for wandb experiment tracking. Make sure to set up wandb on your localhost, and configure the tracking attributes in the Prisma config object. 84 | 85 | ```python 86 | from vit_prisma.models.base_vit import HookedViT 87 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 88 | from vit_prisma.training import trainer 89 | from vit_prisma.dataloaders.induction import InductionDataset 90 | 91 | train_dataset = InductionDataset('train') 92 | 93 | config = HookedViTConfig() 94 | 95 | model_function = HookedViT 96 | 97 | trainer.train(model_function, config, train_dataset) 98 | ``` 99 | 100 | ### Callbacks 101 | 102 | The trainer has support for callbacks, that allows you to pass methods that can be executed after a desired number of epoch(s), or step(s). To setup callback methods, make sure to follow the PrismaCallback protocol. 
103 | 104 | ```python 105 | from vit_prisma.models.base_vit import HookedViT 106 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 107 | from vit_prisma.training import trainer 108 | from vit_prisma.dataloaders.induction import InductionDataset 109 | 110 | from vit_prisma.training.training_utils import PrismaCallback 111 | 112 | class DemoCallback(PrismaCallback): 113 | def on_epoch_end(self, epoch, net, val_loader, wandb_logger): 114 | # Specify a condition if you don't want the callback to execute after each epoch 115 | if epoch % 5 == 0: 116 | # perform some function with the network, validation set, and log it if required. 117 | pass 118 | 119 | def on_step_end(self, step, net, val_loader, wandb_logger): 120 | # It is similar to on_epoch_end but runs after desired number of steps instead of epochs 121 | pass 122 | 123 | train_dataset = InductionDataset('train') 124 | 125 | config = HookedViTConfig() 126 | 127 | model_function = HookedViT 128 | 129 | trainer.train(model_function, config, train_dataset, callbacks=[DemoCallback()]) 130 | ``` 131 | 132 | 133 | -------------------------------------------------------------------------------- /docs/assets/images/attention head 1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/attention head 1.png -------------------------------------------------------------------------------- /docs/assets/images/attention head 2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/attention head 2.png -------------------------------------------------------------------------------- /docs/assets/images/attention head 3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/attention head 3.png -------------------------------------------------------------------------------- /docs/assets/images/cat toilet segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/cat toilet segmentation.png -------------------------------------------------------------------------------- /docs/assets/images/cheetah segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/cheetah segmentation.png -------------------------------------------------------------------------------- /docs/assets/images/child lion segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/child lion segmentation.png -------------------------------------------------------------------------------- /docs/assets/images/corner-head.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/corner-head.gif 
-------------------------------------------------------------------------------- /docs/assets/images/direct logit attribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/direct logit attribution.png -------------------------------------------------------------------------------- /docs/assets/images/dogit lens 2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/dogit lens 2.png -------------------------------------------------------------------------------- /docs/assets/images/dogit lens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/dogit lens.png -------------------------------------------------------------------------------- /docs/assets/images/house.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/house.jpg -------------------------------------------------------------------------------- /docs/assets/images/house2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/house2.jpg -------------------------------------------------------------------------------- /docs/assets/images/monogenic_induction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/monogenic_induction.png -------------------------------------------------------------------------------- /docs/assets/images/patched head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/patched head.png -------------------------------------------------------------------------------- /docs/assets/images/polygenic_induction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/polygenic_induction.png -------------------------------------------------------------------------------- /docs/assets/images/prism1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/prism1.jpg -------------------------------------------------------------------------------- /docs/assets/images/prism2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/prism2.jpg -------------------------------------------------------------------------------- /docs/assets/images/prisma.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/docs/assets/images/prisma.jpg -------------------------------------------------------------------------------- /docs/dataset_descriptions/Induction.md: -------------------------------------------------------------------------------- 1 | ## Induction Dataset 2 | 3 | The Induction Mice are trained on the synthetically generated Induction dataset, which comes in two flavours: monogenic and polygenic. This dataset is designed to enable experiments that could potentially uncover induction heads in ViTs, akin to the induction heads found in language models. 4 | 5 | ## Monogenic Induction dataset 6 | 7 | ![Sample from each class](assets/images/monogenic_induction.png) 8 | 9 | **Classes** 10 | 11 | 0. **Vertical-Same** 12 | 1. **Vertical-Not Same** 13 | 2. **Horizontal-Same** 14 | 3. **Horizontal-Not same** 15 | 16 | Here, the same/not-same indicates whether the structures present in the image are the same or not. 17 | 18 | ## Polygenic Induction dataset 19 | 20 | ![Sample from each class](assets/images/polygenic_induction.png) 21 | 22 | **Classes** 23 | 24 | | ID | Orientation | Pattern | 25 | |----|-------------|---------| 26 | | 0 | H | AAAA | 27 | | 1 | H | ABAB | 28 | | 2 | H | ABBA | 29 | | 3 | H | AABB | 30 | | 4 | H | ABBB | 31 | | 5 | H | AAAB | 32 | | 6 | V | AAAA | 33 | | 7 | V | ABAB | 34 | | 8 | V | ABBA | 35 | | 9 | V | AABB | 36 | | 10 | V | ABBB | 37 | | 11 | V | AAAB | 38 | 39 | The Orientation column indicates whether the structures are arranged horizontally (H) or vertically (V). The Pattern column denotes the sequence in which the structures are arranged. -------------------------------------------------------------------------------- /docs/experiments/ImagenetResults.md: -------------------------------------------------------------------------------- 1 | # ImageNet-1k Results 2 | In our ImageNet experiment, we aimed to assess the performance of Mice ViTs on a more complex and diverse dataset, ImageNet. We trained mice ViTs on classifying the 1000 ImageNet classes. 3 | 4 | ## Training Details 5 | Similar to the dSprites experiment, for each attention layer setting, we explored two model variants: an attention-only model and a model combining attention with the MLP module. Dropout and layer normalization were not applied for simplicity. The detailed training logs and metrics can be found [here](https://wandb.ai/vit-prisma/Imagenet/overview?workspace=user-yash-vadi). 6 | 7 | ## Table of Results 8 | Below table describe the accuracy `[ | ]` of Mice ViTs with different configurations. 
9 | | **Size** | **NumLayers** | **Attention+MLP** | **AttentionOnly** | **Model Link** | 10 | |:--------:|:-------------:|:-----------------:|:-----------------:|--------------------------------------------| 11 | | **tiny** | **1** | 0.16 \| 0.33 | 0.11 \| 0.25 | [AttentionOnly](https://huggingface.co/IamYash/ImageNet-tiny-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/ImageNet-tiny-Attention-and-MLP) | 12 | | **base** | **2** | 0.23 \| 0.44 | 0.16 \| 0.34 | [AttentionOnly](https://huggingface.co/IamYash/ImageNet-base-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/ImageNet-base-Attention-and-MLP) | 13 | | **small**| **3** | 0.28 \| 0.51 | 0.17 \| 0.35 | [AttentionOnly](https://huggingface.co/IamYash/ImageNet-small-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/ImageNet-small-Attention-and-MLP) | 14 | | **medium**|**4** | 0.33 \| 0.56 | 0.17 \| 0.36 | [AttentionOnly](https://huggingface.co/IamYash/ImageNet-medium-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/ImageNet-medium-Attention-and-MLP) | 15 | -------------------------------------------------------------------------------- /docs/experiments/dSprites_results.md: -------------------------------------------------------------------------------- 1 | ## dSprites Results 2 | The dSprites dataset consists of 2D shapes generated procedurally based on six independent latent factors: color, shape, scale, rotation, and the x and y positions of a sprite. Our specific focus was on the task of shape classification within the dSprites dataset, which encompasses a total of three distinct shapes. We trained different Mice ViTs for this task and present the results below. 3 | 4 | ## Training Details 5 | We conducted experiments with Mice ViTs using the following configurations: 6 | 7 | - Number of Attention Layers: Varied from 1 to 4 8 | - Number of Attention Heads: 8 9 | - Training Batch Size: 512 10 | - Image Patch Size: 8 11 | - Learning Rate: 1e-4 12 | - Scheduler Step Size: 200 13 | - Scheduler Gamma: 0.8 14 | 15 | For each attention layer setting, we analyzed two variants: an attention-only model and a model combining attention with the MLP module. We did not apply dropout or layer normalization in this experiment, to keep the models simple. 16 | 17 | ## Table of Results 18 | The table below describes the accuracy of Mice ViTs with different configurations. 
19 | | **Size** | **NumLayers** | **Attention+MLP** | **AttentionOnly** | **Model Link** | 20 | |:--------:|:-------------:|:-----------------:|:-----------------:|--------------------------------------------| 21 | | **tiny** | **1** | 0.535 | 0.459 | [AttentionOnly](https://huggingface.co/IamYash/dSprites-tiny-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/dSprites-tiny-Attention-and-MLP) | 22 | | **base** | **2** | 0.996 | 0.685 | [AttentionOnly](https://huggingface.co/IamYash/dSprites-base-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/dSprites-base-Attention-and-MLP) | 23 | | **small**| **3** | 1.000 | 0.774 | [AttentionOnly](https://huggingface.co/IamYash/dSprites-small-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/dSprites-small-Attention-and-MLP) | 24 | | **medium**|**4** | 1.000 | 0.991 | [AttentionOnly](https://huggingface.co/IamYash/dSprites-medium-AttentionOnly), [Attention+MLP](https://huggingface.co/IamYash/dSprites-medium-Attention-and-MLP) | 25 | -------------------------------------------------------------------------------- /docs/list_of_supported_models.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | **Model name list coming soon** 5 | 6 | # How to load the transformer... 7 | 8 | # Available Vision/Video Transformers to Load # 9 | 10 | # Vision Transformers 11 | 12 | ## Vanilla Vision Transformers 13 | 14 | ### timm 15 | 16 | 17 | 18 | ## CLIP Transformers 19 | 20 | #### OpenAI 21 | 22 | #### timm 23 | 24 | ### LAION 25 | 26 | ### To load the CLIP text encoder... 27 | 28 | 29 | 30 | 31 | ## DINO Transformers 32 | 33 | 34 | 35 | 36 | # Video Transformers 37 | 38 | ## JEPA Transformers (video encoder) 39 | 40 | * Encoder 41 | * Predictor 42 | 43 | 44 | # Not yet supported 45 | 46 | 47 | # Testing the model 48 | 49 | To test for yourself that the model is properly adapted into the HookedViT, check out this file (COMING SOON). 50 | -------------------------------------------------------------------------------- /docs/sae_experiment_tracking.md: -------------------------------------------------------------------------------- 1 | # SAE Experiment Logging and Tracking 2 | 3 | An objective we have when training an SAE is to have a small number of non-zero elements in the hidden layer (i.e. a small L0 'norm'). The L1 norm is used as a proxy for this. 4 | 5 | The loss function of the SAE is: 6 | 7 | $$loss = mse + l1\_coeff * sparsity + [ \: ghost \: gradients \: ]$$ 8 | 9 | The first two terms in the SAE loss function have conflicting goals: 10 | * The **reconstruction term** (MSE) works to make the autoencoder good at reconstructing the input. 11 | * The **sparsity** term works to reduce the magnitudes in the hidden layer. 12 | 13 | When training an SAE most metrics are related to measuring reconstruction loss vs sparsity, in different ways. The goal is to monitor these and find Pareto improvements. 14 | 15 | Weights and Biases is used to track experiments (toggled with `VisionModelSAERunnerConfig.log_to_wandb`). 16 | 17 | # Validation 18 | 19 | Whilst training an SAE, validation metrics are logged every `VisionModelSAERunnerConfig.wandb_log_frequency` steps, whilst the sparsity is logged every `VisionModelSAERunnerConfig.feature_sampling_window` steps. Note that each step is `VisionModelSAERunnerConfig.train_batch_size` number of tokens, for a total of (num_images x num_epochs x tokens_per_image / train_batch_size) steps. 
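To connect the loss terms above with the `losses/*` values logged below, here is a minimal sketch of the loss computation. It illustrates the formula at the top of this document and is not the exact implementation in `src/vit_prisma/sae/sae.py` (which, for example, logs a normalised MSE term).

```python
import torch.nn.functional as F

def sae_loss(x, x_hat, feature_acts, l1_coeff, ghost_grad_loss=0.0):
    # Reconstruction term: how well the SAE output matches its input
    # (plain MSE here for brevity; the repo normalises this term).
    mse = F.mse_loss(x_hat, x)
    # Sparsity term: L1 norm of the feature activations, a differentiable
    # proxy for the L0 "norm" we actually care about.
    sparsity = feature_acts.abs().sum(dim=-1).mean()
    # Optional ghost-gradients term that nudges dead latents toward
    # explaining more of the autoencoder's residual.
    return mse + l1_coeff * sparsity + ghost_grad_loss
```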
20 | 21 | During validation the following metrics are logged: 22 | 23 | #### plots 24 | 25 | * `log_feature_density_histogram` - for each feature, the number of activations greater than 0 since the last log, divided by the total number of activations. See [here](https://arena3-chapter1-transformer-interp.streamlit.app/[1.3.2]_Interpretability_with_SAEs) for a discussion on how to interpret this. 26 | 27 | #### details 28 | 29 | * `details/current_learning_rate` - The learning rate at each step of training. 30 | * `details/n_training_tokens` - The number of training tokens seen so far (images x tokens per image) 31 | * `details/n_training_images` - The number of training images seen so far 32 | 33 | #### losses 34 | 35 | * `losses/mse_loss` - The normalised MSE loss between the SAE input and output, i.e. the reconstruction loss. Normalised such that it is less dependent on the size of the input/output. 36 | * `losses/l1_loss` 37 | * `losses/ghost_grad_loss` - An additional term added to the loss that gives dead latents a gradient signal pushing them in the direction of explaining more of the autoencoder's residual. 38 | * `losses/overall_loss` 39 | 40 | #### metrics 41 | 42 | 43 | * `metrics/explained_variance` - explained variance is the ratio of the variance in the reconstructed data to the variance in the original input data. A higher explained variance indicates that the SAE has learned features that capture more of the meaningful structure in the data. This is the mean over the batch. 44 | * `metrics/explained_variance_std` - This is the std over the batch. 45 | * `metrics/l0` - the number of feature activations greater than 0 for that step, averaged over the batch. 46 | * `metrics/mean_log10_feature_sparsity` 47 | 48 | #### sparsity 49 | 50 | * `sparsity/mean_passes_since_fired` - For each feature, tracks the number of training tokens (images x tokens per image) since its activation was last greater than 0 51 | * `sparsity/dead_features` 52 | * `sparsity/below_1e-5` - feature_sparsity < 1e-5 53 | * `sparsity/below_1e-6` - feature_sparsity < 1e-6 54 | -------------------------------------------------------------------------------- /pretrained_classifiers/clip_benchmark/imagenet_classifier_hf_hub_laion_CLIP_ViT_B_32_DataComp.XL_s13B_b90K.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/pretrained_classifiers/clip_benchmark/imagenet_classifier_hf_hub_laion_CLIP_ViT_B_32_DataComp.XL_s13B_b90K.npy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | install_requires = [ 5 | 'pytest>=6.0', # Test dependency, but included here to auto-install 6 | 'torch', # Test dependency 7 | 'numpy', # Test dependency 8 | 'jaxtyping', # Test dependency 9 | 'einops', # Test dependency 10 | 'fancy_einsum', # Test dependency 11 | 'plotly==5.19.0', 12 | 'timm', # Test dependency 13 | 'transformers', # Test dependency 14 | 'scikit-learn', # Test dependency 15 | 'datasets', 16 | 'line_profiler', 17 | 'matplotlib', 18 | 'wandb', 19 | 'kaleido', 20 | 'open-clip-torch', 21 | ] 22 | 23 | setup( 24 | name='vit-prisma', 25 | version='2.0.0', 26 | author='Sonia Joseph', 27 | author_email='soniamollyjoseph@gmail.com', 28 | description='A Vision Transformer 
library for mechanistic interpretability.', 29 | long_description=open('docs/README.md').read(), 30 | long_description_content_type='text/markdown', 31 | url='https://github.com/soniajoseph/vit-prisma', 32 | packages=find_packages(where='src'), 33 | package_dir={'': 'src'}, 34 | package_data={ 35 | 'vit_prisma': ['visualization/*.html', 'visualization/*.js'], 36 | # Add other patterns here as needed 37 | }, 38 | classifiers=[ 39 | 'Programming Language :: Python :: 3', 40 | 'License :: OSI Approved :: MIT License', 41 | 'Operating System :: OS Independent', 42 | 'Intended Audience :: Science/Research', 43 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 44 | ], 45 | python_requires='>=3.6', 46 | install_requires=install_requires, 47 | keywords='vision-transformer, clip, multimodal, machine-learning, mechanistic interpretability', 48 | zip_safe=False, 49 | extras_require={ 50 | 'sae': ['sae-lens==2.1.3'], 51 | 'arrow': ['pyarrow'] # to use: pip install -e .[sae] # as of 2.1.3, windows will require pip install sae-lens==2.1.3 --no-dependencies followed by manually installing needed packages 52 | }, 53 | ) 54 | -------------------------------------------------------------------------------- /src/vit_prisma/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configs 2 | from . import dataloaders 3 | from . import model_eval 4 | from . import models 5 | from . import prisma_tools 6 | from . import sae 7 | from . import training 8 | from . import transforms 9 | from . import utils 10 | from . import visualization 11 | from . import vjepa_hf 12 | 13 | from .transforms.model_transforms import get_model_transforms 14 | from .models.model_loader import load_hooked_model 15 | 16 | __all__ = [ 17 | configs, 18 | dataloaders, 19 | model_eval, 20 | models, 21 | prisma_tools, 22 | sae, 23 | training, 24 | transforms, 25 | utils, 26 | visualization, 27 | vjepa_hf, 28 | ] 29 | -------------------------------------------------------------------------------- /src/vit_prisma/configs/HookedTextTransformerConfig.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch.nn as nn 3 | import torch 4 | 5 | from typing import Any, Dict, List, Optional 6 | 7 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 8 | 9 | 10 | @dataclass 11 | class HookedTextTransformerConfig(HookedViTConfig): 12 | """Config specific to the text transformer.""" 13 | 14 | context_length: int = 77 15 | vocab_size: int = 10_000 16 | -------------------------------------------------------------------------------- /src/vit_prisma/configs/HookedViTConfig.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch.nn as nn 3 | import torch 4 | 5 | from typing import Any, Dict, List, Optional 6 | 7 | 8 | @dataclass 9 | class HookedViTConfig: 10 | 11 | n_layers: int = None 12 | d_model: int = None 13 | d_head: int = None 14 | d_mlp: int = None 15 | model_name: str = "custom" 16 | use_cls_token: bool = True # Off for ViT 17 | n_heads: int = 4 18 | activation_name: str = "gelu" 19 | d_vocab: int = -1 20 | eps: float = 1e-6 21 | use_attn_result: bool = False 22 | use_attn_scale: bool = True 23 | use_split_qkv_input: bool = False 24 | use_hook_mlp_in: bool = False 25 | use_attn_in: bool = False 26 | use_local_attn: bool = False 27 | original_architecture: Optional[str] = None 28 | from_checkpoint: bool = 
False 29 | checkpoint_index: Optional[int] = None 30 | checkpoint_label_type: Optional[str] = None 31 | checkpoint_value: Optional[int] = None 32 | tokenizer_name: Optional[str] = None 33 | window_size: Optional[int] = None 34 | attn_types: Optional[List] = None 35 | init_mode: str = "gpt2" 36 | normalization_type: Optional[str] = "LN" 37 | normalize_output: bool = False 38 | device: Optional[str] = 'cpu' 39 | n_devices: int = 1 40 | attention_dir: str = "bidirectional" 41 | attn_only: bool = False 42 | seed: Optional[int] = None 43 | initializer_range: float = -1.0 44 | init_weights: bool = True 45 | scale_attn_by_inverse_layer_idx: bool = False 46 | positional_embedding_type: str = "standard" 47 | final_rms: bool = False 48 | d_vocab_out: int = -1 49 | parallel_attn_mlp: bool = False 50 | rotary_dim: Optional[int] = None 51 | n_params: Optional[int] = None 52 | use_hook_tokens: bool = False 53 | gated_mlp: bool = False 54 | default_prepend_bos: bool = True 55 | dtype: torch.dtype = torch.float32 56 | tokenizer_prepends_bos: Optional[bool] = None 57 | n_key_value_heads: Optional[int] = None 58 | post_embedding_ln: bool = False 59 | rotary_base: int = 10000 60 | trust_remote_code: bool = False 61 | rotary_adjacent_pairs: bool = False 62 | 63 | # Layer norm 64 | layer_norm_pre: bool = False # add layernorm before transformer blocks 65 | 66 | #Bert Block 67 | use_bert_block: bool = False 68 | 69 | # Initialization 70 | weight_type: str = 'he' 71 | cls_std: float = 1e-6 72 | pos_std: float = 0.02 73 | 74 | # Image related 75 | n_channels: int = 3 76 | patch_size: int = 32 77 | image_size: int = 224 78 | 79 | # Classification related 80 | classification_type: str = 'cls' 81 | n_classes: int = 10 82 | return_type: str = 'pre_logits' 83 | 84 | # Logging related 85 | log_dir: str = 'logs' 86 | use_wandb: bool = True 87 | wandb_team_name: str = 'perceptual-alignment' 88 | wandb_project_name: str = None 89 | log_frequency: int = 1 90 | print_every: int = 0 91 | 92 | 93 | # Training related 94 | optimizer_name: str = "AdamW" 95 | lr: float = 3e-4 96 | weight_decay: float = 0.01 97 | loss_fn_name: str = "CrossEntropy" 98 | batch_size: int = 512 99 | warmup_steps: int = 10 100 | scheduler_step: int = 200 101 | scheduler_gamma: float = .8 102 | scheduler_type: str = "WarmupThenStep" 103 | early_stopping: bool = False 104 | early_stopping_patience: int = 2 105 | num_epochs: int = 50 106 | max_grad_norm = 1.0 107 | attn_dropout_rate: float = 0.0 108 | mlp_dropout_rate: float = 0.0 109 | 110 | # Saving related 111 | parent_dir: str = "" 112 | save_dir: str = 'Checkpoints' 113 | save_checkpoints: bool = True 114 | save_cp_frequency: int = 5 115 | 116 | # Properties specific to video transformers 117 | is_video_transformer: bool = False 118 | video_tubelet_depth: Optional[int] = None # Can be int or None 119 | video_num_frames: Optional[int] = None # Can be int or None 120 | 121 | 122 | @classmethod 123 | def from_dict(cls, config_dict: Dict[str, Any]): 124 | return cls(**config_dict) -------------------------------------------------------------------------------- /src/vit_prisma/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/configs/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/dataloaders/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/dataloaders/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/dataloaders/cifar10.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/dataloaders/cifar10.py -------------------------------------------------------------------------------- /src/vit_prisma/dataloaders/dsprites.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset, Subset 4 | from sklearn.model_selection import train_test_split 5 | import os 6 | 7 | 8 | class DSpritesDataset(Dataset): 9 | def __init__(self, data_path= '/network/scratch/s/sonia.joseph/datasets/dsprites.npz'): 10 | # Load the dataset 11 | data = np.load(data_path, allow_pickle=True, encoding='latin1') 12 | 13 | self.images = data['imgs'] 14 | self.labels = data['latents_values'][:,1] 15 | 16 | def __len__(self): 17 | return len(self.images) 18 | 19 | def __getitem__(self, idx): 20 | # Convert image from numpy array to PyTorch tensor and add channel dimension 21 | image = torch.tensor(self.images[idx], dtype=torch.float32).unsqueeze(0) 22 | label = torch.tensor(self.labels[idx]-1, dtype=torch.int64) 23 | return image, label 24 | 25 | 26 | def train_test_dataset(dataset, test_split=0.25): 27 | train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=test_split) 28 | datasets = {} 29 | datasets['train'] = Subset(dataset, train_idx) 30 | datasets['test'] = Subset(dataset, val_idx) 31 | return datasets 32 | 33 | 34 | # Example to load the data with train and test split 35 | ## dsprites_dataset = DSpritesDataset(data_path) 36 | ## dsprites_datasets = train_test_dataset(dsprites_dataset, test_split=0.25) 37 | -------------------------------------------------------------------------------- /src/vit_prisma/model_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/model_eval/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/models/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/models/activation_fns.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prisma Repo 3 | By Sonia Joseph 4 | 5 | Copyright (c) Sonia Joseph. All rights reserved. 6 | 7 | Inspired by TransformerLens. Some functions have been adapted from the TransformerLens project. 
8 | For more information on TransformerLens, visit: https://github.com/neelnanda-io/TransformerLens 9 | """ 10 | 11 | import numpy as np 12 | import torch 13 | from typing import Dict, Optional, Tuple, Union 14 | from jaxtyping import Float, Int 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | def quick_gelu(input: Float[torch.Tensor, "batch pos d_mlp"]): 20 | return input * torch.sigmoid(1.702 * input) 21 | 22 | 23 | def gelu_new( 24 | input: Float[torch.Tensor, "batch pos d_mlp"] 25 | ) -> Float[torch.Tensor, "batch pos d_mlp"]: 26 | # Implementation of GeLU used by GPT2 - subtly different from PyTorch's 27 | return ( 28 | 0.5 29 | * input 30 | * ( 31 | 1.0 32 | + torch.tanh( 33 | np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0)) 34 | ) 35 | ) 36 | ) 37 | 38 | 39 | def gelu_fast( 40 | input: Float[torch.Tensor, "batch pos d_mlp"] 41 | ) -> Float[torch.Tensor, "batch pos d_mlp"]: 42 | return ( 43 | 0.5 44 | * input 45 | * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) 46 | ) 47 | 48 | 49 | def solu( 50 | input: Float[torch.Tensor, "batch pos d_mlp"] 51 | ) -> Float[torch.Tensor, "batch pos d_mlp"]: 52 | """ 53 | SoLU activation function as described by 54 | https://transformer-circuits.pub/2022/solu/index.html. 55 | 56 | LayerNorm implemented by the MLP class. 57 | """ 58 | return input * F.softmax(input, dim=-1) -------------------------------------------------------------------------------- /src/vit_prisma/models/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/models/layers/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/models/layers/head.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch 4 | 5 | from typing import Dict, Union 6 | 7 | from jaxtyping import Float 8 | 9 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 10 | from fancy_einsum import einsum 11 | 12 | # maps to class logits 13 | class Head(nn.Module): 14 | def __init__(self, cfg: Union[Dict, HookedViTConfig]): 15 | super().__init__() 16 | if isinstance(cfg, Dict): 17 | cfg = HookedViTConfig.from_dict(cfg) 18 | self.cfg = cfg 19 | 20 | self.W_H: Float[torch.Tensor, "d_model n_classes"] = nn.Parameter( 21 | torch.empty(self.cfg.d_model, self.cfg.n_classes, dtype=cfg.dtype) 22 | ) 23 | self.b_H: Float[torch.Tensor, "n_classes"] = nn.Parameter( 24 | torch.zeros(self.cfg.n_classes, dtype=cfg.dtype) 25 | ) 26 | 27 | def forward( 28 | self, residual: Float[torch.Tensor, "batch d_model"] 29 | ) -> Float[torch.Tensor, "batch n_classes"]: 30 | 31 | return ( 32 | einsum( 33 | "batch d_model, d_model n_classes -> batch n_classes", 34 | residual, 35 | self.W_H, 36 | ) 37 | + self.b_H 38 | ) -------------------------------------------------------------------------------- /src/vit_prisma/models/layers/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from typing import Dict, Optional, Tuple, Union 5 | 6 | from jaxtyping import Float, Int 7 | 8 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 9 | from vit_prisma.prisma_tools.hook_point import HookPoint 10 | 11 | class LayerNormPre(nn.Module): 12 | def __init__(self, cfg: Union[Dict, HookedViTConfig]): 13 
| """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is 14 | normally d_model, but is d_mlp for softmax. Not needed as a parameter. This 15 | should only be used in inference mode after folding in LayerNorm weights""" 16 | super().__init__() 17 | if isinstance(cfg, Dict): 18 | cfg = HookedViTConfig.from_dict(cfg) 19 | self.cfg = cfg 20 | self.eps = self.cfg.eps 21 | 22 | # Adds a hook point for the normalisation scale factor 23 | self.hook_scale = HookPoint() # [batch, pos] 24 | # Hook Normalized captures LN output - here it's a vector with std 1 and mean 0 25 | self.hook_normalized = HookPoint() # [batch, pos, length] 26 | 27 | def forward( 28 | self, 29 | x: Union[ 30 | Float[torch.Tensor, "batch pos d_model"], 31 | Float[torch.Tensor, "batch pos head_index d_model"], 32 | ], 33 | ) -> Union[ 34 | Float[torch.Tensor, "batch pos d_model"], 35 | Float[torch.Tensor, "batch pos head_index d_model"], 36 | ]: 37 | if self.cfg.dtype not in [torch.float32, torch.float64]: 38 | x = x.to(torch.float32) 39 | 40 | x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] 41 | scale: Union[ 42 | Float[torch.Tensor, "batch pos 1"], 43 | Float[torch.Tensor, "batch pos head_index 1"], 44 | ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()) 45 | return self.hook_normalized(x / scale).to(self.cfg.dtype) 46 | 47 | 48 | class LayerNorm(nn.Module): 49 | def __init__( 50 | self, cfg: Union[Dict, HookedViTConfig], length: Optional[int] = None 51 | ): 52 | """ 53 | LayerNorm with optional length parameter 54 | 55 | length (Optional[int]): If the dimension of the LayerNorm. If not provided, assumed to be d_model 56 | """ 57 | super().__init__() 58 | if isinstance(cfg, Dict): 59 | cfg = HookedViTConfig.from_dict(cfg) 60 | self.cfg = cfg 61 | self.eps = self.cfg.eps 62 | if length is None: 63 | self.length = self.cfg.d_model 64 | else: 65 | self.length = length 66 | 67 | self.w = nn.Parameter(torch.ones(self.length, dtype=cfg.dtype)) 68 | self.b = nn.Parameter(torch.zeros(self.length, dtype=cfg.dtype)) 69 | 70 | # Adds a hook point for the normalisation scale factor 71 | self.hook_scale = HookPoint() # [batch, pos, 1] 72 | # Hook_normalized is on the LN output 73 | self.hook_normalized = HookPoint() # [batch, pos, length] 74 | 75 | def forward( 76 | self, 77 | x: Union[ 78 | Float[torch.Tensor, "batch pos d_model"], 79 | Float[torch.Tensor, "batch pos head_index d_model"], 80 | ], 81 | ) -> Union[ 82 | Float[torch.Tensor, "batch pos d_model"], 83 | Float[torch.Tensor, "batch pos head_index d_model"], 84 | ]: 85 | if self.cfg.dtype not in [torch.float32, torch.float64]: 86 | x = x.to(torch.float32) 87 | 88 | x = x - x.mean(axis=-1, keepdim=True) # [batch, pos, length] 89 | scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( 90 | (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() 91 | ) 92 | x = x / scale # [batch, pos, length] 93 | return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype) 94 | -------------------------------------------------------------------------------- /src/vit_prisma/models/layers/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from fancy_einsum import einsum 7 | from jaxtyping import Float 8 | from vit_prisma.configs.HookedTextTransformerConfig import HookedTextTransformerConfig 9 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 10 | from 
vit_prisma.models.activation_fns import gelu_fast, gelu_new, solu, quick_gelu 11 | from vit_prisma.models.layers.layer_norm import LayerNorm, LayerNormPre 12 | from vit_prisma.prisma_tools.hook_point import HookPoint 13 | 14 | 15 | class MLP(nn.Module): 16 | 17 | def __init__(self, cfg: Union[Dict, HookedViTConfig]): 18 | super().__init__() 19 | 20 | if not isinstance(cfg, (HookedViTConfig, HookedTextTransformerConfig)): 21 | cfg = HookedViTConfig.from_dict(cfg) 22 | 23 | self.cfg = cfg 24 | 25 | self.W_in = nn.Parameter( 26 | torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=self.cfg.dtype) 27 | ) 28 | self.b_in = nn.Parameter( 29 | torch.empty(self.cfg.d_mlp, dtype=self.cfg.dtype) 30 | ) 31 | self.W_out = nn.Parameter( 32 | torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype) 33 | ) 34 | self.b_out = nn.Parameter( 35 | torch.empty(self.cfg.d_model, dtype=self.cfg.dtype) 36 | ) 37 | 38 | self.hook_pre = HookPoint() 39 | self.hook_post = HookPoint() 40 | 41 | if self.cfg.activation_name == "relu": 42 | self.act_fn = F.relu 43 | elif self.cfg.activation_name == "gelu": 44 | self.act_fn = F.gelu 45 | elif self.cfg.activation_name == "silu": 46 | self.act_fn = F.silu 47 | elif self.cfg.activation_name == "gelu_new": 48 | self.act_fn = gelu_new 49 | elif self.cfg.activation_name == "gelu_fast": 50 | self.act_fn = gelu_fast 51 | elif self.cfg.activation_name == "quick_gelu": 52 | self.act_fn = quick_gelu 53 | elif self.cfg.activation_name == "solu_ln": # why does only solu have a layernorm? 54 | self.act_fn = solu 55 | # Hook taken between activation and layer norm 56 | self.hook_mid = HookPoint() # [batch, pos, d_mlp] 57 | if self.cfg.normalization_type == "LN": 58 | self.ln = LayerNorm(self.cfg, self.cfg.d_mlp) 59 | else: 60 | self.ln = LayerNormPre(self.cfg) 61 | 62 | else: 63 | raise ValueError(f"Invalid activation function name: {self.cfg.activation_name}") 64 | 65 | def forward(self, x: Float[torch.Tensor, "batch pos d_model"] 66 | ) -> Float[torch.Tensor, "batch pos d_model"]: 67 | 68 | pre_act = self.hook_pre( 69 | einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) 70 | + self.b_in 71 | ) 72 | if not self.cfg.activation_name.endswith("_ln"): 73 | post_act = self.hook_post(self.act_fn(pre_act)) 74 | else: 75 | mid_act = self.hook_mid(self.act_fn(pre_act)) 76 | post_act = self.hook_post(self.ln(mid_act)) 77 | return ( 78 | einsum("batch pos d_mlp, d_mlp d_model -> batch pos d_model", post_act, self.W_out) 79 | + self.b_out 80 | ) 81 | 82 | -------------------------------------------------------------------------------- /src/vit_prisma/models/layers/patch_embedding.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | import torch 4 | from jaxtyping import Float 5 | from vit_prisma.configs import HookedViTConfig 6 | import einops 7 | 8 | class PatchEmbedding(nn.Module): 9 | 10 | def __init__(self, config, logger=None): 11 | super().__init__() 12 | self.logger = logger 13 | self.config = config 14 | self.proj = nn.Conv2d( 15 | self.config.n_channels, 16 | self.config.d_model, 17 | kernel_size=self.config.patch_size, 18 | stride=self.config.patch_size, 19 | bias=True 20 | ) 21 | 22 | def _log(self, stage, tensor): 23 | if self.logger: 24 | self.logger.info(f"{stage} size: {tensor.shape}") 25 | 26 | def forward(self, x:Float[torch.Tensor, "batch channel height width"]) -> Float[torch.Tensor, "batch n_tokens d_model"]: 27 | self._log("PatchEmbedding input", x) 28 | 29 | x = 
self.proj(x).flatten(2).transpose(1, 2) 30 | 31 | self._log("PatchEmbedding output", x) 32 | return x 33 | 34 | 35 | # Used for videos, 3 dimensional spacetime patches 36 | class TubeletEmbedding(nn.Module): 37 | 38 | def __init__(self, cfg:HookedViTConfig): 39 | super().__init__() 40 | self.cfg = cfg 41 | 42 | tubelet_size = [self.cfg.video_tubelet_depth, self.cfg.patch_size, self.cfg.patch_size] 43 | self.proj = nn.Conv3d( 44 | self.cfg.n_channels, 45 | self.cfg.d_model, 46 | kernel_size=tubelet_size, 47 | stride=tubelet_size, 48 | bias=True 49 | ) 50 | 51 | def forward(self, x:Float[torch.Tensor, "batch num_frames channels height width"]) -> Float[torch.Tensor, "batch n_tokens d_model"]: 52 | 53 | # Flip num_frames and channels 54 | # x = einops.rearrange(x, "b t c h w -> b c t h w") 55 | 56 | x = self.proj(x) 57 | 58 | # Flatten the tokens 59 | x = einops.rearrange(x, "b c t h w -> b (t h w) c") 60 | 61 | return x 62 | 63 | -------------------------------------------------------------------------------- /src/vit_prisma/models/layers/position_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import einops 5 | 6 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 7 | 8 | from typing import Dict, Optional, Tuple, Union 9 | 10 | from jaxtyping import Float, Int 11 | 12 | class PosEmbedding(nn.Module): 13 | 14 | def __init__(self, cfg: Union[Dict, HookedViTConfig]): 15 | 16 | super().__init__() 17 | 18 | if isinstance(cfg, Dict): 19 | cfg = HookedViTConfig.from_dict(cfg) 20 | self.cfg = cfg 21 | 22 | num_patches = (self.cfg.image_size // self.cfg.patch_size)**2 23 | if self.cfg.is_video_transformer: 24 | num_patches = num_patches*(self.cfg.video_num_frames//self.cfg.video_tubelet_depth) 25 | 26 | token_length = num_patches + 1 if self.cfg.use_cls_token else num_patches 27 | 28 | self.W_pos = nn.Parameter( 29 | torch.empty(token_length, self.cfg.d_model, dtype=self.cfg.dtype) 30 | ) 31 | 32 | def forward( 33 | self, 34 | tokens: Int[torch.Tensor, "batch pos"], 35 | ): 36 | pos_embed = self.W_pos 37 | batch_pos_embed = einops.repeat(pos_embed, "pos d_model -> batch pos d_model", batch=tokens.size(0)) 38 | return batch_pos_embed 39 | 40 | -------------------------------------------------------------------------------- /src/vit_prisma/prisma_tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .hook_point import HookPoint 3 | from .hooked_root_module import HookedRootModule 4 | from .factored_matrix import FactoredMatrix -------------------------------------------------------------------------------- /src/vit_prisma/prisma_tools/hook_point.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prisma Repo 3 | By Sonia Joseph 4 | 5 | Copyright (c) Sonia Joseph. All rights reserved. 6 | 7 | Inspired by TransformerLens. Some functions have been adapted from the TransformerLens project. 8 | For more information on TransformerLens, visit: https://github.com/neelnanda-io/TransformerLens 9 | """ 10 | 11 | from typing import List, Union, Dict, Callable, Tuple, Optional, Any 12 | from vit_prisma.prisma_tools.lens_handle import LensHandle 13 | import torch.nn as nn 14 | 15 | 16 | class HookPoint(nn.Module): 17 | """ 18 | A helper class to access intermediate activations in a PyTorch model (adapted from TransformerLens, which was inspired by Garcon). 
19 | 20 | HookPoint is a dummy module that acts as an identity function by default. By wrapping any 21 | intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks. 22 | """ 23 | def __init__(self): 24 | super().__init__() 25 | self.fwd_hooks: List[LensHandle] = [] 26 | self.bwd_hooks: List[LensHandle] = [] 27 | 28 | self.ctx = {} # scratch storage that hook functions can read/write; reset via clear_context() 29 | 30 | self.name = None 31 | 32 | def add_perma_hook(self, hook, dir="fwd") -> None: 33 | self.add_hook(hook, dir, is_permanent=True) 34 | 35 | def add_hook( 36 | self, hook, dir="fwd", is_permanent=False, level=None, prepend=False 37 | ) -> None: 38 | """ 39 | If prepend is True, add this hook before all other hooks. 40 | """ 41 | 42 | if dir == "fwd": 43 | 44 | def full_hook(module, module_input, module_output): 45 | return hook(module_output, hook=self) 46 | 47 | full_hook.__name__ = ( 48 | hook.__repr__() 49 | ) 50 | 51 | handle = self.register_forward_hook(full_hook) 52 | handle = LensHandle(handle, is_permanent, level) 53 | 54 | if prepend: 55 | self._forward_hooks.move_to_end(handle.hook.id, last=False) 56 | self.fwd_hooks.insert(0, handle) 57 | else: 58 | self.fwd_hooks.append(handle) 59 | 60 | elif dir == "bwd": 61 | 62 | def full_hook(module, module_input, module_output): 63 | return hook(module_output[0], hook=self) 64 | 65 | full_hook.__name__ = ( 66 | hook.__repr__() 67 | ) 68 | 69 | handle = self.register_backward_hook(full_hook) 70 | handle = LensHandle(handle, is_permanent, level) 71 | 72 | if prepend: 73 | self._backward_hooks.move_to_end(handle.hook.id, last=False) 74 | self.bwd_hooks.insert(0, handle) 75 | else: 76 | self.bwd_hooks.append(handle) 77 | 78 | else: 79 | raise ValueError(f"Invalid dir {dir}. dir must be 'fwd' or 'bwd'") 80 | 81 | def remove_hooks(self, dir="fwd", including_permanent=False, level=None) -> None: 82 | def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]: 83 | output_handles = [] 84 | for handle in handles: 85 | if including_permanent: 86 | handle.hook.remove() 87 | elif (not handle.is_permanent) and (level is None or handle.context_level == level): 88 | handle.hook.remove() 89 | else: 90 | output_handles.append(handle) 91 | return output_handles 92 | 93 | if dir == "fwd" or dir == "both": 94 | self.fwd_hooks = _remove_hooks(self.fwd_hooks) 95 | if dir == "bwd" or dir == "both": 96 | self.bwd_hooks = _remove_hooks(self.bwd_hooks) 97 | if dir not in ("fwd", "bwd", "both"): 98 | raise ValueError(f"Invalid direction {dir}. dir must be 'fwd', 'bwd', or 'both'") 99 | 100 | def clear_context(self): 101 | del self.ctx 102 | self.ctx = {} 103 | 104 | def forward(self, x): 105 | return x 106 | 107 | def layer(self): 108 | # Returns the layer index if the name has the form 'blocks.{layer}.{...}' 109 | # Helper function that's mainly useful on HookedViT 110 | # If it doesn't have this form, raises an error 111 | split_name = self.name.split(".") 112 | return int(split_name[1]) 113 | 114 | -------------------------------------------------------------------------------- /src/vit_prisma/prisma_tools/lens_handle.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prisma Repo 3 | By Sonia Joseph 4 | 5 | Copyright (c) Sonia Joseph. All rights reserved. 6 | 7 | Inspired by TransformerLens. Some functions have been adapted from the TransformerLens project.
8 | For more information on TransformerLens, visit: https://github.com/neelnanda-io/TransformerLens 9 | """ 10 | 11 | from dataclasses import dataclass 12 | 13 | from typing import Optional 14 | 15 | import torch.utils.hooks as hooks 16 | 17 | 18 | @dataclass 19 | class LensHandle: 20 | """Dataclass that holds information about a PyTorch hook.""" 21 | 22 | hook: hooks.RemovableHandle 23 | """Reference to the Hook's Removable Handle.""" 24 | 25 | is_permanent: bool = False 26 | """Indicates if the Hook is Permanent.""" 27 | 28 | context_level: Optional[int] = None 29 | """Context level associated with the hooks context manager for the given hook.""" 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/vit_prisma/prisma_tools/logit_lens.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prisma Repo 3 | By Sonia Joseph 4 | 5 | Copyright (c) Sonia Joseph. All rights reserved. 6 | 7 | Inspired by TransformerLens. Some functions have been adapted from the TransformerLens project. 8 | For more information on TransformerLens, visit: https://github.com/neelnanda-io/TransformerLens 9 | """ 10 | 11 | import numpy as np 12 | import torch 13 | from collections import defaultdict 14 | from typing import Union, Optional, Dict, List, Tuple 15 | 16 | from vit_prisma.utils.data_utils.imagenet.imagenet_dict import IMAGENET_DICT 17 | from vit_prisma.utils.data_utils.imagenet.imagenet_utils import imagenet_index_from_word 18 | 19 | 20 | def get_patch_logit_directions(cache, all_answers: torch.Tensor, incl_mid: bool = False, return_labels: bool = True) -> tuple: 21 | """ 22 | Computes the patch logit directions based on accumulated residuals from the cache. 23 | 24 | Args: 25 | cache: An object that provides methods to access and process model residuals. 26 | all_answers (torch.Tensor): A tensor containing all possible answers with shape (num_answers, d_model). 27 | incl_mid (bool, optional): Whether to include intermediate layers. Default is False. 28 | return_labels (bool, optional): Whether to return labels along with the result. Default is True. 29 | 30 | Returns: 31 | tuple: A tuple containing: 32 | - result (torch.Tensor): The computed logit directions with shape (batch_size, num_patches, num_labels, num_answers). 33 | - labels: Labels associated with the accumulated residuals, if `return_labels` is True. 34 | """ 35 | 36 | accumulated_residual, labels = cache.accumulated_resid( 37 | layer=-1, incl_mid=incl_mid, return_labels=True 38 | ) 39 | scaled_residual_stack = cache.apply_ln_to_stack( 40 | accumulated_residual, layer=-1, 41 | ) 42 | result = torch.einsum('lbpd,od -> lbpo', scaled_residual_stack, all_answers) 43 | # Rearrange so batches are first 44 | result = result.permute(1, 2, 0, 3) 45 | return result, labels 46 | 47 | def get_patch_logit_dictionary( 48 | patch_logit_directions: Union[torch.Tensor, Tuple[torch.Tensor, ...]], 49 | batch_idx: int = 0, 50 | rank_label: Optional[str] = None 51 | ) -> Dict[int, List[Tuple[float, str, int, Optional[int]]]]: 52 | """ 53 | Constructs a dictionary of patch logit predictions for a given batch index. 54 | 55 | Args: 56 | patch_logit_directions (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): A tensor or a tuple of tensors 57 | containing the logit directions with shape 58 | (batch_size, num_patches, num_labels, num_answers). 59 | batch_idx (int, optional): The index of the batch to process. Default is 0. 
60 | rank_label (Optional[str], optional): A label to rank against the predictions. Default is None. 61 | 62 | Returns: 63 | Dict[int, List[Tuple[float, str, int, Optional[int]]]]: A dictionary where each key is a patch index and each value is a list of tuples. 64 | Each tuple contains the logit, predicted class name, predicted index, 65 | and optionally the rank of the rank_label. 66 | """ 67 | 68 | patch_dictionary = defaultdict(list) 69 | # If a tuple was passed, keep only the first entry (the logit directions) 70 | if isinstance(patch_logit_directions, tuple): 71 | patch_logit_directions = patch_logit_directions[0] 72 | # Iterate over the patches of the selected batch element 73 | for patch_idx, patches in enumerate(patch_logit_directions[batch_idx]): 74 | # For each layer's logits at this patch, take the top prediction 75 | for logits in patches: 76 | probs = torch.softmax(logits, dim=-1) 77 | # Get index of max prediction 78 | predicted_idx = int(torch.argmax(probs)) 79 | logit = logits[predicted_idx].item() 80 | predicted_class_name = IMAGENET_DICT[predicted_idx] 81 | if rank_label: 82 | # Where does the rank_label fall in the sorted predictions? 83 | rank_index = imagenet_index_from_word(rank_label) 84 | sorted_list = torch.argsort(probs, descending=True) 85 | rank = np.where(sorted_list == rank_index)[0][0] 86 | patch_dictionary[patch_idx].append((logit, predicted_class_name, predicted_idx, rank)) 87 | else: 88 | patch_dictionary[patch_idx].append((logit, predicted_class_name, predicted_idx)) 89 | return patch_dictionary 90 |
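# --- Editor-added usage sketch (not part of the repository) ---
# A minimal, hedged illustration of the two logit-lens helpers defined above.
# Assumptions not taken from this file: the model is a HookedViT exposing a
# TransformerLens-style `run_with_cache`, its classification head is reachable as
# `model.head` (the Head module from models/layers/head.py), the checkpoint name is
# illustrative, and class names are looked up in IMAGENET_DICT, so this only makes
# sense for an ImageNet classifier.
import torch
from vit_prisma.models.base_vit import HookedViT
from vit_prisma.prisma_tools.logit_lens import (
    get_patch_logit_directions,
    get_patch_logit_dictionary,
)

model = HookedViT.from_pretrained("vit_base_patch16_224")  # assumed timm checkpoint name
images = torch.randn(1, 3, 224, 224)                       # stand-in for a preprocessed batch
_, cache = model.run_with_cache(images)

# One direction per class: Head.W_H is (d_model, n_classes), so transpose to (n_classes, d_model).
class_directions = model.head.W_H.T

# directions: (batch, patch, layer, class); labels name each accumulated-residual layer
directions, labels = get_patch_logit_directions(cache, class_directions)
patch_preds = get_patch_logit_dictionary(directions, batch_idx=0, rank_label="tabby")
# patch_preds[patch_idx] is a list over layers of (logit, class_name, class_idx, rank) tuples.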
-------------------------------------------------------------------------------- /src/vit_prisma/sae/__init__.py: -------------------------------------------------------------------------------- 1 | from .sae import StandardSparseAutoencoder, GatedSparseAutoencoder, SparseAutoencoder 2 | from .config import VisionModelSAERunnerConfig, CacheActivationsRunnerConfig 3 | from .train_sae import VisionSAETrainer 4 | from .evals.model_eval import SparsecoderEval -------------------------------------------------------------------------------- /src/vit_prisma/sae/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sae/evals/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/sae/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sae/training/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/sae/training/geometric_median.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from typing import Optional 3 | 4 | import torch 5 | import tqdm 6 | 7 | def weighted_average(points: torch.Tensor, weights: torch.Tensor): 8 | weights = weights / weights.sum() 9 | return (points * weights.view(-1, 1)).sum(dim=0) 10 | 11 | 12 | @torch.no_grad() 13 | def geometric_median_objective( 14 | median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor 15 | ) -> torch.Tensor: 16 | 17 | norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore 18 | 19 | return (norms * weights).sum() 20 | 21 | 22 | 23 | def compute_geometric_median( 24 | points: torch.Tensor, 25 | weights: Optional[torch.Tensor] = None, 26 | eps: float = 1e-6, 27 | maxiter: int = 100, 28 | ftol: float = 1e-20, 29 | do_log: bool = False, 30 | ): 31 | """ 32 | :param points: ``torch.Tensor`` of shape ``(n, d)`` 33 | :param weights: Optional ``torch.Tensor`` of shape ``(n,)``. 34 | :param eps: Smallest allowed value of the denominator, to avoid divide by zero. 35 | Equivalently, this is a smoothing parameter. Default 1e-6. 36 | :param maxiter: Maximum number of Weiszfeld iterations. Default 100. 37 | :param ftol: If the objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20. 38 | :param do_log: If True, also return the list of objective values encountered over the course of the algorithm. 39 | :return: SimpleNamespace object with fields 40 | - `median`: estimate of the geometric median, a ``torch.Tensor`` of shape ``(d,)`` 41 | - `termination`: string explaining how the algorithm terminated. 42 | - `logs`: objective values encountered over the course of the algorithm, as a list (None if do_log is False). 43 | """ 44 | with torch.no_grad(): 45 | 46 | if weights is None: 47 | weights = torch.ones((points.shape[0],), device=points.device) 48 | # initialize median estimate at mean 49 | new_weights = weights 50 | median = weighted_average(points, weights) 51 | objective_value = geometric_median_objective(median, points, weights) 52 | if do_log: 53 | logs = [objective_value] 54 | else: 55 | logs = None 56 | 57 | # Weiszfeld iterations 58 | early_termination = False 59 | pbar = tqdm.tqdm(range(maxiter)) 60 | for _ in pbar: 61 | prev_obj_value = objective_value 62 | 63 | norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore 64 | new_weights = weights / torch.clamp(norms, min=eps) 65 | median = weighted_average(points, new_weights) 66 | objective_value = geometric_median_objective(median, points, weights) 67 | 68 | if logs is not None: 69 | logs.append(objective_value) 70 | if abs(prev_obj_value - objective_value) <= ftol * objective_value: 71 | early_termination = True 72 | break 73 | 74 | pbar.set_description(f"Objective value: {objective_value:.4f}") 75 | 76 | median = weighted_average(points, new_weights) # allow autodiff to track it 77 | return SimpleNamespace( 78 | median=median, 79 | new_weights=new_weights, 80 | termination=( 81 | "function value converged within tolerance" 82 | if early_termination 83 | else "maximum iterations reached" 84 | ), 85 | logs=logs, 86 | ) -------------------------------------------------------------------------------- /src/vit_prisma/sae/training/get_scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Took the LR scheduler from Joseph Bloom's previous work: 3 | https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425 4 | """ 5 | 6 | import math 7 | from typing import Any, Optional 8 | 9 | import torch.optim as optim 10 | import torch.optim.lr_scheduler as lr_scheduler 11 | 12 | 13 | # None 14 | # Linear Warmup and decay 15 | # Cosine Annealing with Warmup 16 | # Cosine Annealing with Warmup / Restarts 17 | def get_scheduler( 18 | scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs: Any 19 | ): 20 | """ 21 | Loosely based on this reference; it seemed simpler to write it than to import 22 | transformers: https://huggingface.co/docs/transformers/main_classes/optimizer_schedules 23 | 24 | Args: 25 | scheduler_name (Optional[str]): Name of the scheduler to use.
If None, returns a constant scheduler 26 | optimizer (optim.Optimizer): Optimizer to use 27 | **kwargs: Additional arguments to pass to the scheduler including warm_up_steps, 28 | training_steps, num_cycles, lr_end. 29 | """ 30 | 31 | def get_warmup_lambda(warm_up_steps: int, training_steps: int): 32 | 33 | def lr_lambda(steps: int): 34 | if steps < warm_up_steps: 35 | return (steps + 1) / warm_up_steps 36 | else: 37 | return (training_steps - steps) / (training_steps - warm_up_steps) 38 | 39 | return lr_lambda 40 | 41 | # heavily derived from hugging face although copilot helped. 42 | def get_warmup_cosine_lambda( 43 | warm_up_steps: int, training_steps: int, lr_end: float 44 | ): 45 | 46 | def lr_lambda(steps: int): 47 | if steps < warm_up_steps: 48 | return (steps + 1) / warm_up_steps 49 | else: 50 | progress = (steps - warm_up_steps) / (training_steps - warm_up_steps) 51 | return lr_end + 0.5 * (1 - lr_end) * (1 + math.cos(math.pi * progress)) 52 | 53 | return lr_lambda 54 | 55 | if scheduler_name is None or scheduler_name.lower() == "constant": 56 | return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0) 57 | elif scheduler_name.lower() == "constantwithwarmup": 58 | warm_up_steps = kwargs.get("warm_up_steps", 0) 59 | return lr_scheduler.LambdaLR( 60 | optimizer, 61 | lr_lambda=lambda steps: min(1.0, (steps + 1) / warm_up_steps), 62 | ) 63 | elif scheduler_name.lower() == "linearwarmupdecay": 64 | warm_up_steps = kwargs.get("warm_up_steps", 0) 65 | training_steps = kwargs.get("training_steps") 66 | assert training_steps is not None, "training_steps must be provided" 67 | lr_lambda = get_warmup_lambda(warm_up_steps, training_steps) 68 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 69 | elif scheduler_name.lower() == "cosineannealing": 70 | training_steps = kwargs.get("training_steps") 71 | assert training_steps is not None, "training_steps must be provided" 72 | eta_min = kwargs.get("lr_end", 0) 73 | return lr_scheduler.CosineAnnealingLR( 74 | optimizer, T_max=training_steps, eta_min=eta_min 75 | ) 76 | elif scheduler_name.lower() == "cosineannealingwarmup": 77 | warm_up_steps = kwargs.get("warm_up_steps", 0) 78 | training_steps = kwargs.get("training_steps") 79 | assert training_steps is not None, "training_steps must be provided" 80 | eta_min = kwargs.get("lr_end", 0) 81 | lr_lambda = get_warmup_cosine_lambda(warm_up_steps, training_steps, eta_min) 82 | return lr_scheduler.LambdaLR(optimizer, lr_lambda) 83 | elif scheduler_name.lower() == "cosineannealingwarmrestarts": 84 | training_steps = kwargs.get("training_steps") 85 | eta_min = kwargs.get("lr_end", 0) 86 | num_cycles = kwargs.get("num_cycles", 1) 87 | T_0 = training_steps // num_cycles 88 | return lr_scheduler.CosineAnnealingWarmRestarts( 89 | optimizer, T_0=T_0, eta_min=eta_min 90 | ) 91 | else: 92 | raise ValueError(f"Unsupported scheduler: {scheduler_name}") 93 | -------------------------------------------------------------------------------- /src/vit_prisma/sae/transcoder.py: -------------------------------------------------------------------------------- 1 | from vit_prisma.sae.sae import SparseAutoencoder 2 | import torch.nn as nn 3 | import torch 4 | import einops 5 | 6 | class Transcoder(SparseAutoencoder): 7 | 8 | def initialize_sae_weights(self): 9 | 10 | self.W_skip = nn.Parameter(self.initialize_weights(self.d_in, self.d_in)) if self.cfg.transcoder_with_skip_connection else None 11 | 12 | self.W_dec = nn.Parameter(self.initialize_weights(self.d_sae, self.cfg.d_out)) 13 | 14 | self.W_enc = 
nn.Parameter(self.initialize_weights(self.d_in, self.d_sae)) 15 | 16 | self.b_enc = nn.Parameter( 17 | torch.zeros(self.d_sae, dtype=self.dtype, device=self.device) 18 | ) 19 | 20 | self.b_dec = nn.Parameter( 21 | torch.zeros(self.d_in, dtype=self.dtype, device=self.device) 22 | ) 23 | 24 | self.b_dec_out = nn.Parameter( 25 | torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device) 26 | ) 27 | 28 | 29 | def encode(self, x: torch.Tensor, return_hidden_pre: bool = False): 30 | # move x to correct dtype 31 | x = x.to(self.dtype) 32 | 33 | sae_in = self.run_time_activation_norm_fn_in(x) 34 | 35 | sae_in = self.hook_sae_in( 36 | sae_in - self.b_dec 37 | ) 38 | 39 | hidden_pre = self.hook_hidden_pre( 40 | einops.einsum( 41 | sae_in, 42 | self.W_enc, 43 | "... d_in, d_in d_sae -> ... d_sae", 44 | ) 45 | + self.b_enc 46 | ) 47 | 48 | feature_acts = self.hook_hidden_post(self.activation_fn(hidden_pre)) 49 | 50 | if return_hidden_pre: 51 | return sae_in, feature_acts, hidden_pre 52 | 53 | return sae_in, feature_acts 54 | 55 | def decode(self, features: torch.Tensor): 56 | sae_out = self.hook_sae_out( 57 | einops.einsum( 58 | features, 59 | self.W_dec, 60 | "... d_sae, d_sae d_out -> ... d_out", 61 | ) 62 | + self.b_dec_out 63 | ) 64 | 65 | return sae_out 66 | 67 | def forward( 68 | self, x: torch.Tensor, y: torch.Tensor = None, dead_neuron_mask: torch.Tensor = None, *args, **kwargs 69 | 70 | ): 71 | 72 | _, feature_acts, hidden_pre = self.encode(x, return_hidden_pre=True) 73 | sae_out = self.decode(feature_acts) 74 | 75 | if self.W_skip is not None: 76 | sae_out += x.to(self.dtype) @ self.W_skip.mT 77 | 78 | sae_out = self.run_time_activation_norm_fn_out(sae_out) 79 | 80 | mse_loss = self._compute_mse_loss(y, sae_out) 81 | 82 | # Compute ghost residual loss if required 83 | if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None: 84 | mse_loss_ghost_resid = self._compute_ghost_residual_loss( 85 | x, sae_out, hidden_pre, dead_neuron_mask 86 | ) 87 | else: 88 | mse_loss_ghost_resid = self.zero_loss 89 | 90 | # Compute sparsity loss 91 | sparsity = feature_acts.norm(p=self.lp_norm, dim=1).mean(dim=(0,)) 92 | 93 | # Combine losses based on activation function type 94 | l1_loss = ( 95 | self.l1_coefficient * sparsity 96 | if self.cfg.activation_fn_str != "topk" 97 | else None 98 | ) 99 | 100 | loss = mse_loss + (l1_loss if l1_loss is not None else 0) + mse_loss_ghost_resid 101 | 102 | # Placeholder for auxiliary reconstruction loss 103 | aux_reconstruction_loss = torch.tensor(0.0) 104 | 105 | if hasattr(self.cfg, "return_out_only"): # to work with HookedSAEViT efficiently 106 | if self.cfg.return_out_only: 107 | return sae_out 108 | 109 | return ( 110 | sae_out, 111 | feature_acts, 112 | loss, 113 | mse_loss, 114 | l1_loss, 115 | mse_loss_ghost_resid, 116 | aux_reconstruction_loss, 117 | ) -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/cat_crop.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/cat_crop.jpeg -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/cat_dog.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/cat_dog.jpeg -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/cheetah.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/cheetah.jpeg -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/child_lion.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/child_lion.jpeg -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/crop_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/crop_cat.png -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/crop_dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/crop_dog.jpeg -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/fruit_bowl.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/fruit_bowl.jpeg -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/n01818515_39.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/n01818515_39.JPEG -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/n01818515_39_typographic.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/n01818515_39_typographic.JPEG -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/sea_slug_typogrpahic.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/sea_slug_typogrpahic.JPEG -------------------------------------------------------------------------------- /src/vit_prisma/sample_images/toilet_cat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/sample_images/toilet_cat.jpeg -------------------------------------------------------------------------------- /src/vit_prisma/training/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/training/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/training/early_stopping.py: -------------------------------------------------------------------------------- 1 | 2 | class EarlyStopping: 3 | def __init__(self, patience=3, verbose=False, delta=0): 4 | self.patience = patience 5 | self.verbose = verbose 6 | self.counter = 0 7 | self.best_score = None 8 | self.early_stop = False 9 | self.val_accuracy_max = float('-inf') 10 | self.delta = delta 11 | 12 | def __call__(self, val_accuracy): 13 | score = val_accuracy 14 | 15 | if self.best_score is None: 16 | self.best_score = score 17 | elif score < self.best_score + self.delta: 18 | self.counter += 1 19 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 20 | if self.counter >= self.patience: 21 | print("Early stopping triggered!") 22 | self.early_stop = True 23 | else: 24 | self.best_score = score 25 | self.counter = 0 26 | -------------------------------------------------------------------------------- /src/vit_prisma/training/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | class WarmupThenStepLR(torch.optim.lr_scheduler._LRScheduler): 7 | def __init__(self, optimizer, warmup_steps, step_size, gamma=0.5, last_epoch=-1): 8 | # Does warmup for warmup steps, then moves to step decay parameterized by step_size 9 | self.warmup_steps = warmup_steps 10 | self.step_size = step_size 11 | self.gamma = gamma 12 | super(WarmupThenStepLR, self).__init__(optimizer, last_epoch) 13 | 14 | def get_lr(self): 15 | if self.last_epoch < self.warmup_steps: 16 | return [base_lr * (self.last_epoch / self.warmup_steps) for base_lr in self.base_lrs] 17 | else: 18 | # Check if it's time for step decay 19 | if (self.last_epoch - self.warmup_steps + 1) % self.step_size == 0: 20 | current_lr = self.optimizer.param_groups[0]['lr'] # Assuming all parameter groups have the same learning rate 21 | print(f"Reducing learning rate from {current_lr} to {current_lr * self.gamma} at epoch {self.last_epoch}") 22 | 23 | return [base_lr * (self.gamma ** ((self.last_epoch - self.warmup_steps) // self.step_size)) 24 | for base_lr in self.base_lrs] 25 | 26 | 27 | class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler): 28 | def __init__(self, optimizer, warmup_steps, total_steps, min_lr=1e-6, last_epoch=-1): 29 | """Cosine Annealing with linear warmup learning rate scheduler. It has a linear 30 | warmup phase for warmup_steps steps and then transitions to cosine annealing. 31 | 32 | Args: 33 | optimizer: The optimiser for which to schedule the learning rate 34 | warmup_steps: Number of warmup steps at the start. 35 | total_steps: Total number of training steps. 36 | min_lr: Minimum learning rate to anneal to. 37 | last_epoch: The index of last epoch. 
38 | """ 39 | self.warmup_steps = warmup_steps 40 | self.total_steps = total_steps 41 | self.min_lr = min_lr 42 | 43 | # Initialize parent class - this creates self.base_lrs from optimizer 44 | super().__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | if self.last_epoch < self.warmup_steps: 48 | # Linear warmup: go from min_lr to base_lr 49 | warmup_progress = self.last_epoch / self.warmup_steps 50 | return [self.min_lr + (base_lr - self.min_lr) * warmup_progress for base_lr in self.base_lrs] 51 | else: 52 | # Cosine annealing phase 53 | progress = (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps) 54 | cosine_factor = 0.5 * (1 + math.cos(math.pi * progress)) 55 | return [self.min_lr + (base_lr - self.min_lr) * cosine_factor 56 | for base_lr in self.base_lrs] 57 | -------------------------------------------------------------------------------- /src/vit_prisma/training/training_dictionary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | activation_dict ={ 4 | "ReLU": torch.nn.ReLU, 5 | "LeakyReLU": torch.nn.LeakyReLU, 6 | "GELU": torch.nn.GELU, 7 | "Linear": torch.nn.Identity, 8 | } 9 | 10 | initialization_dict = { 11 | "ReLU": 'relu', 12 | "LeakyReLU": 'relu', 13 | "GELU": 'relu', # for 'he' intitializatoin 14 | } 15 | 16 | optimizer_dict = { 17 | "AdamW": torch.optim.AdamW, 18 | } 19 | 20 | loss_function_dict = { 21 | "CrossEntropy": torch.nn.CrossEntropyLoss, 22 | "MSE": torch.nn.MSELoss, 23 | } -------------------------------------------------------------------------------- /src/vit_prisma/training/training_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from itertools import islice 4 | import numpy as np 5 | import random 6 | from vit_prisma.training.training_dictionary import loss_function_dict 7 | from abc import ABC, abstractmethod 8 | 9 | class PrismaCallback(ABC): 10 | @abstractmethod 11 | def on_epoch_end(self, epoch, net, val_loader, wandb_logger): 12 | pass 13 | 14 | @abstractmethod 15 | def on_step_end(self, step, net, val_loader, wandb_logger): 16 | pass 17 | 18 | def calculate_accuracy(net, data_loader, device, N=2000, batch_size=50): 19 | net.eval() 20 | with torch.no_grad(): 21 | correct = 0 22 | total = 0 23 | for items in islice(data_loader, N // batch_size): 24 | x, labels, *extras = items 25 | logits = net(x.to(device)) 26 | predictions = torch.argmax(logits, dim=1) 27 | correct += torch.sum(predictions == labels.to(device)).item() 28 | total += len(labels) 29 | return correct / total 30 | 31 | def calculate_loss(net, data_loader, loss_fn, device, N=2000, batch_size=50): 32 | net.eval() 33 | with torch.no_grad(): 34 | total = 0 35 | points = 0 36 | for items in islice(data_loader, N // batch_size): 37 | x, labels, *extras = items 38 | logits = net(x.to(device)) 39 | total += loss_fn(logits, labels.to(device)).item() 40 | points += len(labels) 41 | return total / points 42 | 43 | def set_seed(seed): 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | random.seed(seed) 47 | np.random.seed(seed) -------------------------------------------------------------------------------- /src/vit_prisma/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_transforms import * -------------------------------------------------------------------------------- /src/vit_prisma/transforms/model_transforms.py: 
-------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torchvision.transforms.functional import InterpolationMode 3 | from transformers import AutoImageProcessor 4 | 5 | 6 | def _convert_to_rgb(image): 7 | return image.convert('RGB') 8 | 9 | def get_clip_val_transforms( 10 | image_size=224, 11 | mean=[0.48145466, 0.4578275, 0.40821073], 12 | std=[0.26862954, 0.26130258, 0.27577711] 13 | ): 14 | return transforms.Compose([ 15 | transforms.Resize(size=image_size, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True), 16 | transforms.CenterCrop(size=(image_size, image_size)), 17 | _convert_to_rgb, 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=mean, std=std) 20 | ]) 21 | 22 | def get_model_transforms( 23 | model_name 24 | ): 25 | if model_name.startswith("open-clip:"): 26 | return get_clip_val_transforms() 27 | else: 28 | try: 29 | img_processor = AutoImageProcessor.from_pretrained(model_name) 30 | img_size = img_processor.size['height'] 31 | return transforms.Compose([ 32 | transforms.Resize(size=img_size, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias=True), 33 | transforms.CenterCrop(size=(img_size, img_size)), 34 | _convert_to_rgb, 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=img_processor.image_mean, std=img_processor.image_std) 37 | ]) 38 | except: 39 | raise ValueError(f"Image processor for {model_name} not found. Please define the appropriate data transforms") 40 | 41 | -------------------------------------------------------------------------------- /src/vit_prisma/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/utils/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/utils/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | from pathlib import Path 5 | 6 | BASE_DIR = Path(__file__).resolve().parent.parent.parent.parent 7 | DATA_DIR = BASE_DIR / "data" 8 | MODEL_DIR = BASE_DIR / "models" 9 | MODEL_CHECKPOINTS_DIR = MODEL_DIR / "checkpoints" 10 | SAE_CHECKPOINTS_DIR = MODEL_DIR / "sae" 11 | 12 | def get_device(): 13 | if torch.cuda.is_available(): 14 | return torch.device("cuda") 15 | elif torch.backends.mps.is_available(): 16 | return torch.device("mps") 17 | else: 18 | return torch.device("cpu") 19 | 20 | DEVICE = get_device() 21 | -------------------------------------------------------------------------------- /src/vit_prisma/utils/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/utils/data_utils/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/utils/data_utils/cifar/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/utils/data_utils/cifar/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/utils/data_utils/cifar/cifar_10_utils.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | from torch.utils.data import random_split, Dataset 5 | from torchvision import datasets 6 | from torchvision import transforms 7 | from torchvision.transforms import RandAugment 8 | 9 | 10 | def get_cifar_transforms(augmentation: bool, image_size: int = 128, visualisation=False): 11 | cifar_transforms = [ 12 | transforms.ToTensor(), 13 | ] 14 | 15 | if visualisation: 16 | cifar_transforms = cifar_transforms + [transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),] 17 | 18 | cifar_transforms = cifar_transforms + [transforms.Resize((image_size, image_size)),] 19 | 20 | if augmentation: 21 | cifar_transforms = [ 22 | transforms.RandomResizedCrop(128, scale=(0.8, 1.0), ratio=(0.9, 1.1)), # Moderate crop 23 | transforms.RandomHorizontalFlip(), 24 | transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # Added color jitter 25 | RandAugment(2, 10), 26 | ] + cifar_transforms 27 | 28 | transform = transforms.Compose(cifar_transforms) 29 | 30 | return transform 31 | 32 | 33 | def load_cifar_10( 34 | dataset_path: str, 35 | split_size: float = 0.8, 36 | augmentation: bool = False, 37 | image_size: int = 128, 38 | with_index=False, 39 | visualisation=False, 40 | ) -> Tuple[Dataset, Dataset, Dataset]: 41 | """Load CIFAR-10 from Torchvision. It will cache the data locally.""" 42 | 43 | if visualisation: 44 | train_transform = get_cifar_transforms(augmentation, image_size, visualisation=True) 45 | test_transform = get_cifar_transforms(augmentation=False, image_size=image_size, visualisation=True) 46 | else: 47 | train_transform = get_cifar_transforms(augmentation, image_size, visualisation=False) 48 | test_transform = get_cifar_transforms(augmentation=False, image_size=image_size, visualisation=False) 49 | 50 | print(f"Downloading or fetching the CIFAR-10 dataset to: {dataset_path}") 51 | 52 | cifar10_trainset = datasets.CIFAR10( 53 | root=dataset_path, 54 | train=True, 55 | download=True, 56 | transform=train_transform, 57 | ) 58 | cifar10_test_dataset = datasets.CIFAR10( 59 | root=dataset_path, 60 | train=False, 61 | download=True, 62 | transform=test_transform, 63 | ) 64 | 65 | train_size = int(len(cifar10_trainset) * split_size) 66 | val_size = len(cifar10_trainset) - train_size 67 | 68 | cifar10_train_dataset, cifar10_val_dataset = random_split( 69 | cifar10_trainset, 70 | [train_size, val_size], 71 | generator=torch.Generator().manual_seed(42), 72 | ) 73 | print(f"There are {len(cifar10_train_dataset)} samples in the train dataset") 74 | print(f"There are {len(cifar10_val_dataset)} samples in the validation dataset") 75 | print(f"There are {len(cifar10_test_dataset)} samples in the test dataset") 76 | # Subset(cifar10_test_dataset, indices=list(range(len(cifar10_test_dataset)))) 77 | if with_index: 78 | return IndexPreservingSubset(cifar10_train_dataset.dataset, 79 | indices=cifar10_train_dataset.indices), IndexPreservingSubset( 80 | cifar10_val_dataset.dataset, indices=cifar10_val_dataset.indices), cifar10_test_dataset 81 | else: 82 | return cifar10_train_dataset, cifar10_val_dataset, cifar10_test_dataset -------------------------------------------------------------------------------- /src/vit_prisma/utils/data_utils/imagenet/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/utils/data_utils/imagenet/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/utils/data_utils/imagenet/imagenet_utils.py: -------------------------------------------------------------------------------- 1 | from vit_prisma.utils.data_utils.imagenet.imagenet_dict import IMAGENET_DICT 2 | 3 | import os 4 | 5 | 6 | def imagenet_index_from_word(search_term: str) -> int: 7 | """ 8 | Finds the ImageNet index corresponding to a search term. 9 | 10 | Args: 11 | search_term (str): The search term to look up in the ImageNet dictionary. 12 | 13 | Returns: 14 | int: The index corresponding to the search term in the ImageNet dictionary. 15 | 16 | Raises: 17 | ValueError: If the search term is not found in the ImageNet dictionary. 18 | """ 19 | 20 | # Convert the search term to lowercase to ensure case-insensitive matching 21 | search_term = search_term.lower() 22 | 23 | # Iterate over the dictionary and search for the term 24 | for key, value in IMAGENET_DICT.items(): 25 | if ( 26 | search_term in value.lower() 27 | ): # Convert each value to lowercase for case-insensitive comparison 28 | return key # Return the key directly once found 29 | 30 | # If the loop completes without returning, the term was not found; raise an exception 31 | raise ValueError(f"'{search_term}' not found in IMAGENET_DICT.") 32 | 33 | 34 | def setup_imagenet_paths(imagenet_path, format="kaggle") -> dict: 35 | # Currently based on kaggle dataset structure, can be modified for particular dataset type. 36 | 37 | if format == "kaggle": 38 | return { 39 | "train": os.path.join(imagenet_path, "ILSVRC/Data/CLS-LOC/train"), 40 | "val": os.path.join(imagenet_path, "ILSVRC/Data/CLS-LOC/val"), 41 | "val_labels": os.path.join(imagenet_path, "LOC_val_solution.csv"), 42 | "label_strings": os.path.join(imagenet_path, "LOC_synset_mapping.txt"), 43 | } 44 | 45 | elif format == "hhi": 46 | return { 47 | "train": os.path.join(imagenet_path, "train"), 48 | "val": os.path.join(imagenet_path, "val"), 49 | "val_labels": os.path.join(imagenet_path, "LOC_val_solution.csv"), 50 | "label_strings": os.path.join(imagenet_path, "LOC_synset_mapping.txt"), 51 | } 52 | -------------------------------------------------------------------------------- /src/vit_prisma/utils/data_utils/loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torchvision 4 | from torch.utils.data import Dataset 5 | 6 | from vit_prisma.dataloaders.imagenet_dataset import ( 7 | ImageNetValidationDataset, 8 | ) 9 | from vit_prisma.transforms.model_transforms import get_model_transforms 10 | from vit_prisma.utils.data_utils.cifar.cifar_10_utils import load_cifar_10 11 | from vit_prisma.utils.data_utils.imagenet.imagenet_utils import setup_imagenet_paths 12 | 13 | 14 | class SubsetDataset(Dataset): 15 | """For testing purposes - reduce a Dataset to N of its samples.""" 16 | 17 | def __init__(self, dataset, n, random_subset=False): 18 | self.dataset = dataset 19 | self.n = min(n, len(dataset)) 20 | if random_subset: 21 | self.indices = random.sample(range(len(dataset)), self.n) 22 | else: 23 | self.indices = list(range(self.n)) 24 | 25 | self.data = [] 26 | for idx in self.indices: 27 | self.data.append(self.dataset[idx]) 28 | 29 | del self.dataset 30 | 31 | def __getitem__(self, index): 32 | return self.data[index] 33 | 34 | def __len__(self): 35 | 
return self.n 36 | 37 | 38 | def load_dataset(cfg, visualize=False): 39 | """Load and prepare datasets based on configuration and model type. Currently the 40 | function loads either ImageNet1K or CIFAR-10 datasets. 41 | """ 42 | print(f"Dataset type: {cfg.dataset_name}") if cfg.verbose else None 43 | 44 | data_transforms = get_model_transforms(cfg.model_name) 45 | 46 | if cfg.dataset_name == "imagenet1k": 47 | imagenet_paths = setup_imagenet_paths(cfg.dataset_path) 48 | train_data = torchvision.datasets.ImageFolder( 49 | cfg.dataset_train_path, transform=data_transforms 50 | ) 51 | val_data = ImageNetValidationDataset( 52 | cfg.dataset_val_path, 53 | imagenet_paths["label_strings"], 54 | imagenet_paths["val_labels"], 55 | data_transforms, 56 | return_index=True, 57 | ) 58 | if visualize: 59 | val_data_visualize = ImageNetValidationDataset( 60 | cfg.dataset_val_path, 61 | imagenet_paths["label_strings"], 62 | imagenet_paths["val_labels"], 63 | torchvision.transforms.Compose( 64 | [ 65 | torchvision.transforms.Resize((224, 224)), 66 | torchvision.transforms.ToTensor(), 67 | ] 68 | ), 69 | return_index=True, 70 | ) 71 | else: 72 | val_data_visualize = None 73 | 74 | print(f"Train data length: {len(train_data)}") if cfg.verbose else None 75 | print(f"Validation data length: {len(val_data)}") if cfg.verbose else None 76 | return train_data, val_data, val_data_visualize 77 | elif cfg.dataset_name == "cifar10": 78 | train_data, val_data, _ = load_cifar_10(cfg.dataset_path) 79 | return train_data, val_data, None 80 | else: 81 | # raise error 82 | raise ValueError("Invalid dataset name") 83 | -------------------------------------------------------------------------------- /src/vit_prisma/utils/devices.py: -------------------------------------------------------------------------------- 1 | """Devices. 2 | 3 | Utilities to get the correct device, and assist in distributing model layers across multiple 4 | devices. 5 | """ 6 | from __future__ import annotations 7 | 8 | from typing import Optional, Union 9 | 10 | import torch 11 | from torch import nn 12 | 13 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 14 | 15 | from vit_prisma.prisma_tools.hooked_root_module import HookedRootModule 16 | 17 | def get_device_for_block_index( 18 | index: int, 19 | cfg: HookedViTConfig, 20 | device: Optional[Union[torch.device, str]] = None, 21 | ): 22 | """ 23 | Determine the device for a given layer index based on the model configuration. 24 | 25 | This function assists in distributing model layers across multiple devices. The distribution 26 | is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices). 27 | 28 | Args: 29 | index (int): Model layer index. 30 | cfg (HookedTransformerConfig): Model and device configuration. 31 | device (Optional[Union[torch.device, str]], optional): Initial device used for determining the target device. 32 | If not provided, the function uses the device specified in the configuration (cfg.device). 33 | 34 | Returns: 35 | torch.device: The device for the specified layer index. 
36 | """ 37 | assert cfg.device is not None 38 | layers_per_device = cfg.n_layers // cfg.n_devices 39 | if device is None: 40 | device = cfg.device 41 | device = torch.device(device) 42 | device_index = (device.index or 0) + (index // layers_per_device) 43 | return torch.device(device.type, device_index) 44 | 45 | 46 | def move_to_and_update_config( 47 | # model: Union[ 48 | # HookedViT 49 | # ], 50 | model: "HookedViT", 51 | device_or_dtype: Union[torch.device, str, torch.dtype], 52 | print_details=True, 53 | ): 54 | """ 55 | Wrapper around `to` that also updates `model.cfg`. 56 | """ 57 | if isinstance(device_or_dtype, torch.device): 58 | model.cfg.device = device_or_dtype.type 59 | if print_details: 60 | print("Moving model to device: ", model.cfg.device) 61 | elif isinstance(device_or_dtype, str): 62 | model.cfg.device = device_or_dtype 63 | if print_details: 64 | print("Moving model to device: ", model.cfg.device) 65 | elif isinstance(device_or_dtype, torch.dtype): 66 | model.cfg.dtype = device_or_dtype 67 | if print_details: 68 | print("Changing model dtype to", device_or_dtype) 69 | # change state_dict dtypes 70 | for k, v in model.state_dict().items(): 71 | model.state_dict()[k] = v.to(device_or_dtype) 72 | return nn.Module.to(model, device_or_dtype) -------------------------------------------------------------------------------- /src/vit_prisma/utils/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ModelType(Enum): 5 | TEXT = "text" 6 | VISION = "vision" 7 | -------------------------------------------------------------------------------- /src/vit_prisma/utils/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | from typing import Tuple 4 | 5 | import torch 6 | from vit_prisma.models.base_text_transformer import HookedTextTransformer 7 | from vit_prisma.models.base_vit import HookedViT 8 | from huggingface_hub import hf_hub_download 9 | from vit_prisma.sae import VisionModelSAERunnerConfig 10 | from vit_prisma.sae import SparseAutoencoder 11 | from vit_prisma.utils.constants import DEVICE 12 | from vit_prisma.utils.enums import ModelType 13 | 14 | 15 | def load_model( 16 | cfg: VisionModelSAERunnerConfig, 17 | model_from_pretrained_kwargs: dict = None 18 | ) -> HookedViT: 19 | model_from_pretrained_kwargs = model_from_pretrained_kwargs or {} 20 | 21 | if "n_devices" in model_from_pretrained_kwargs: 22 | n_devices = model_from_pretrained_kwargs["n_devices"] 23 | if n_devices > 1: 24 | print("MODEL LOADING:") 25 | print("Setting model device to cuda for d_devices") 26 | print(f"Will use cuda:0 to cuda:{n_devices-1}") 27 | device = "cuda" 28 | print("-------------") 29 | 30 | if cfg.model_class_name == "HookedViT": 31 | is_timm, is_clip, is_local = set_flags(cfg.model_name) 32 | 33 | if is_local: 34 | return HookedViT.from_local(cfg.prisma_vit_cfg, cfg.model_path).to(cfg.device) 35 | 36 | fold_ln = cfg.model_name != "openai/clip-vit-base-patch32" 37 | 38 | return HookedViT.from_pretrained(cfg.model_name, is_timm=is_timm, is_clip=is_clip, fold_ln=fold_ln).to(cfg.device) 39 | 40 | else: # pragma: no cover 41 | raise ValueError(f"Unknown model class: {cfg.model_class_name}") 42 | 43 | 44 | def set_flags(model_name): 45 | model_name_lower = model_name.lower() 46 | is_clip = 'clip' in model_name_lower 47 | is_local = 'local' in model_name_lower 48 | is_timm = (not is_clip) and (not is_local) 49 | return is_timm, 
is_clip, is_local 50 | 51 | 52 | def load_remote_sae_and_model( 53 | repo_id: str, 54 | checkpoint="n_images_2600058.pt", 55 | config_file: str = "config.json", 56 | current_cfg: dict = None 57 | ) -> Tuple[SparseAutoencoder, HookedTextTransformer, HookedViT]: 58 | """Load and test SAE from HuggingFace.""" 59 | 60 | print(f"Loading SAE from repo_id: {repo_id} with checkpoint: {checkpoint}") 61 | sae_path = hf_hub_download(repo_id, checkpoint) 62 | sae_config_path = hf_hub_download(repo_id, config_file) 63 | 64 | sae = SparseAutoencoder.load_from_pretrained(sae_path, config_path=sae_config_path, current_cfg=current_cfg, ) 65 | 66 | print(f"Loading model name: {sae.cfg.model_name}") 67 | print(f"The config device is: {sae.cfg.device}") 68 | language_model = HookedTextTransformer.from_pretrained(sae.cfg.model_name, is_timm=False, is_clip=True, model_type=ModelType.TEXT).to(sae.cfg.device) 69 | model = HookedViT.from_pretrained(sae.cfg.model_name, is_timm=False, is_clip=True).to(sae.cfg.device) 70 | 71 | sae = sae.to(DEVICE) 72 | print(f"Using device: {DEVICE}") 73 | 74 | return sae, language_model, model 75 | -------------------------------------------------------------------------------- /src/vit_prisma/utils/openai_templates.py: -------------------------------------------------------------------------------- 1 | OPENAI_IMAGENET_TEMPLATES = ( 2 | lambda c: f'a bad photo of a {c}.', 3 | lambda c: f'a photo of many {c}.', 4 | lambda c: f'a sculpture of a {c}.', 5 | lambda c: f'a photo of the hard to see {c}.', 6 | lambda c: f'a low resolution photo of the {c}.', 7 | lambda c: f'a rendering of a {c}.', 8 | lambda c: f'graffiti of a {c}.', 9 | lambda c: f'a bad photo of the {c}.', 10 | lambda c: f'a cropped photo of the {c}.', 11 | lambda c: f'a tattoo of a {c}.', 12 | lambda c: f'the embroidered {c}.', 13 | lambda c: f'a photo of a hard to see {c}.', 14 | lambda c: f'a bright photo of a {c}.', 15 | lambda c: f'a photo of a clean {c}.', 16 | lambda c: f'a photo of a dirty {c}.', 17 | lambda c: f'a dark photo of the {c}.', 18 | lambda c: f'a drawing of a {c}.', 19 | lambda c: f'a photo of my {c}.', 20 | lambda c: f'the plastic {c}.', 21 | lambda c: f'a photo of the cool {c}.', 22 | lambda c: f'a close-up photo of a {c}.', 23 | lambda c: f'a black and white photo of the {c}.', 24 | lambda c: f'a painting of the {c}.', 25 | lambda c: f'a painting of a {c}.', 26 | lambda c: f'a pixelated photo of the {c}.', 27 | lambda c: f'a sculpture of the {c}.', 28 | lambda c: f'a bright photo of the {c}.', 29 | lambda c: f'a cropped photo of a {c}.', 30 | lambda c: f'a plastic {c}.', 31 | lambda c: f'a photo of the dirty {c}.', 32 | lambda c: f'a jpeg corrupted photo of a {c}.', 33 | lambda c: f'a blurry photo of the {c}.', 34 | lambda c: f'a photo of the {c}.', 35 | lambda c: f'a good photo of the {c}.', 36 | lambda c: f'a rendering of the {c}.', 37 | lambda c: f'a {c} in a video game.', 38 | lambda c: f'a photo of one {c}.', 39 | lambda c: f'a doodle of a {c}.', 40 | lambda c: f'a close-up photo of the {c}.', 41 | lambda c: f'a photo of a {c}.', 42 | lambda c: f'the origami {c}.', 43 | lambda c: f'the {c} in a video game.', 44 | lambda c: f'a sketch of a {c}.', 45 | lambda c: f'a doodle of the {c}.', 46 | lambda c: f'a origami {c}.', 47 | lambda c: f'a low resolution photo of a {c}.', 48 | lambda c: f'the toy {c}.', 49 | lambda c: f'a rendition of the {c}.', 50 | lambda c: f'a photo of the clean {c}.', 51 | lambda c: f'a photo of a large {c}.', 52 | lambda c: f'a rendition of a {c}.', 53 | lambda c: f'a photo 
of a nice {c}.', 54 | lambda c: f'a photo of a weird {c}.', 55 | lambda c: f'a blurry photo of a {c}.', 56 | lambda c: f'a cartoon {c}.', 57 | lambda c: f'art of a {c}.', 58 | lambda c: f'a sketch of the {c}.', 59 | lambda c: f'a embroidered {c}.', 60 | lambda c: f'a pixelated photo of a {c}.', 61 | lambda c: f'itap of the {c}.', 62 | lambda c: f'a jpeg corrupted photo of the {c}.', 63 | lambda c: f'a good photo of a {c}.', 64 | lambda c: f'a plushie {c}.', 65 | lambda c: f'a photo of the nice {c}.', 66 | lambda c: f'a photo of the small {c}.', 67 | lambda c: f'a photo of the weird {c}.', 68 | lambda c: f'the cartoon {c}.', 69 | lambda c: f'art of the {c}.', 70 | lambda c: f'a drawing of the {c}.', 71 | lambda c: f'a photo of the large {c}.', 72 | lambda c: f'a black and white photo of a {c}.', 73 | lambda c: f'the plushie {c}.', 74 | lambda c: f'a dark photo of a {c}.', 75 | lambda c: f'itap of a {c}.', 76 | lambda c: f'graffiti of the {c}.', 77 | lambda c: f'a toy {c}.', 78 | lambda c: f'itap of my {c}.', 79 | lambda c: f'a photo of a cool {c}.', 80 | lambda c: f'a photo of a small {c}.', 81 | lambda c: f'a tattoo of the {c}.', 82 | ) 83 | -------------------------------------------------------------------------------- /src/vit_prisma/utils/saving_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | 4 | from pathlib import Path 5 | import re 6 | 7 | 8 | # Save config along with trainig files 9 | 10 | 11 | def object_to_dict(obj): 12 | if isinstance(obj, (int, float, str, bool, type(None))): 13 | return obj 14 | elif isinstance(obj, dict): 15 | return {k: object_to_dict(v) for k, v in obj.items() if not callable(v)} 16 | elif isinstance(obj, list): 17 | return [object_to_dict(item) for item in obj] 18 | elif hasattr(obj, "__dict__"): # for custom objects 19 | # Filter out methods or built-in functions 20 | return object_to_dict({k: v for k, v in obj.__dict__.items() if not callable(v)}) 21 | elif isinstance(obj, torch.dtype): 22 | return str(obj) 23 | else: 24 | raise TypeError(f"Object of type {type(obj)} is not serializable to JSON") 25 | 26 | # Function to save config to a JSON file 27 | def save_config_to_file(config, file_path): 28 | with open(file_path, 'w') as f: 29 | json.dump(object_to_dict(config), f, indent=4) 30 | 31 | 32 | # Get current Prisma version 33 | def get_version(): 34 | # Read the version from setup.py 35 | setup_py_path = Path(__file__).parent.parent.parent.parent 36 | setup_py_path = setup_py_path / 'setup.py' 37 | 38 | # Check if setup.py exists 39 | if not setup_py_path.exists(): 40 | raise FileNotFoundError(f"setup.py not found at {setup_py_path}") 41 | 42 | # Read the content of setup.py 43 | with open(setup_py_path, 'r') as f: 44 | content = f.read() 45 | 46 | # Try to find version using regex 47 | version_match = re.search(r"version\s*=\s*['\"]([^'\"]+)['\"]", content) 48 | if version_match: 49 | return version_match.group(1) 50 | 51 | # If version is not found, print the content of setup.py 52 | print("Content of setup.py:") 53 | print(content) 54 | 55 | raise ValueError("Version not found in setup.py. 
Content of the file is printed above.") -------------------------------------------------------------------------------- /src/vit_prisma/utils/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | def update_dataclass_from_dict(dc, dct): 4 | for key, value in dct.items(): 5 | if hasattr(dc, key): 6 | attr = getattr(dc, key) 7 | if dataclasses.is_dataclass(attr): 8 | update_dataclass_from_dict(attr, value) 9 | else: 10 | setattr(dc, key, value) 11 | 12 | def update_config_with_wandb_sweep(config, sweep_values): 13 | update_dataclass_from_dict(config, sweep_values) 14 | 15 | def dataclass_to_dict(obj): 16 | if dataclasses.is_dataclass(obj): 17 | return {name: dataclass_to_dict(value) for name, value in dataclasses.asdict(obj).items()} 18 | else: 19 | return obj 20 | -------------------------------------------------------------------------------- /src/vit_prisma/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/visualization/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/visualization/get_color.js: -------------------------------------------------------------------------------- 1 | function getColor(intensity, colorMapName='viridis') { 2 | 3 | const colorMaps = { 4 | deep: [ 5 | { pos: 0.0, rgb: [76, 114, 176] }, 6 | { pos: 0.1, rgb: [221, 132, 82] }, 7 | { pos: 0.2, rgb: [85, 168, 104] }, 8 | { pos: 0.3, rgb: [196, 78, 82] }, 9 | { pos: 0.4, rgb: [129, 114, 179] }, 10 | { pos: 0.5, rgb: [147, 120, 96] }, 11 | { pos: 0.6, rgb: [218, 139, 195] }, 12 | { pos: 0.7, rgb: [140, 140, 140] }, 13 | { pos: 0.8, rgb: [204, 185, 116] }, 14 | { pos: 0.9, rgb: [100, 181, 205] }, 15 | { pos: 1.0, rgb: [253, 231, 37] } // This last color is not part of seaborn's deep, it's added for completion 16 | ], 17 | viridis: [ 18 | {pos: 0, rgb: [68, 1, 84]} , 19 | {pos: 0.1, rgb: [72, 34, 115]}, 20 | {pos: 0.2, rgb: [64, 67, 135]}, 21 | {pos: 0.3, rgb: [52, 94, 141]}, 22 | {pos: 0.4, rgb: [41, 120, 142]}, 23 | {pos: 0.5, rgb: [32, 144, 140]}, 24 | {pos: 0.6, rgb: [34, 167, 132]}, 25 | {pos: 0.7, rgb: [68, 190, 112]}, 26 | {pos: 0.8, rgb: [121, 209, 81]}, 27 | {pos: 0.9, rgb: [189, 222, 38]}, 28 | {pos: 1.0, rgb: [253, 231, 37]} 29 | ] 30 | }; 31 | 32 | // Select the color map based on the colorMapName argument 33 | const colorMap = colorMaps[colorMapName]; 34 | 35 | for (let i = 0; i < colorMap.length - 1; i++) { 36 | const start = colorMap[i]; 37 | const end = colorMap[i + 1]; 38 | if (intensity >= start.pos && intensity < end.pos) { 39 | const ratio = (intensity - start.pos) / (end.pos - start.pos); 40 | const r = Math.floor(start.rgb[0] + ratio * (end.rgb[0] - start.rgb[0])); 41 | const g = Math.floor(start.rgb[1] + ratio * (end.rgb[1] - start.rgb[1])); 42 | const b = Math.floor(start.rgb[2] + ratio * (end.rgb[2] - start.rgb[2])); 43 | return `rgba(${r}, ${g}, ${b}, 1.0)`; 44 | } 45 | } 46 | return `rgba(253, 231, 37, 1.0)`; 47 | } 48 | -------------------------------------------------------------------------------- /src/vit_prisma/visualization/patch_to_img.js: -------------------------------------------------------------------------------- 1 | function patchToImageData(patch, width, height) { 2 | var imgData = new ImageData(width, height); 3 | var data = imgData.data; 4 | for (let p = 0, q = 0; p < 
patch.length; p += 3, q += 4) { 5 | data.set([patch[p], patch[p + 1], patch[p + 2], 255], q); 6 | } 7 | return imgData; 8 | } -------------------------------------------------------------------------------- /src/vit_prisma/visualization/sample_cifar10_image_0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/visualization/sample_cifar10_image_0.npy -------------------------------------------------------------------------------- /src/vit_prisma/visualization/sample_cifar10_image_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/visualization/sample_cifar10_image_10.npy -------------------------------------------------------------------------------- /src/vit_prisma/visualization/template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 |
16 | 17 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/vit_prisma/visualization/visualize_attention.py: -------------------------------------------------------------------------------- 1 | # Helper function: javascript visualization 2 | import numpy as np 3 | import json 4 | from IPython.display import display, HTML 5 | import string 6 | import random 7 | 8 | # Helper function to plot attention patterns, hide this function 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | 14 | def plot_attn_heads( 15 | total_activations, 16 | n_heads=12, 17 | n_layers=12, 18 | img_shape=50, 19 | idx=0, 20 | figsize=(20, 20), 21 | global_min_max=False, 22 | global_normalize=False, 23 | fourier_transform_local=False, 24 | log_transform=False, 25 | fourier_transform_global=False, 26 | graph_type="imshow_graph", 27 | cmap="viridis", 28 | ): 29 | 30 | # New shape handling: total_activations is now expected to be of shape [n_layers*n_heads, img_shape, img_shape] 31 | total_data = np.zeros((n_layers * n_heads, img_shape, img_shape)) 32 | 33 | # Adjusted processing for flattened layer-heads structure 34 | if global_min_max or global_normalize or fourier_transform_global: 35 | for i in range(n_layers * n_heads): 36 | data = total_activations[i, :, :] 37 | if log_transform: 38 | data = np.log10(np.maximum(data, 1e-6)) # log10_stable equivalent 39 | if fourier_transform_global: 40 | data = np.abs(np.fft.fftshift(np.fft.fft2(data))) 41 | total_data[i, :, :] = data 42 | 43 | total_min, total_max = np.min(total_data), np.max(total_data) 44 | print(f"Total Min: {total_min}, Total Max: {total_max}") 45 | 46 | if global_normalize: 47 | total_data = -1 + 2 * (total_data - total_min) / (total_max - total_min) 48 | 49 | fig, axes = plt.subplots( 50 | n_layers, n_heads, figsize=figsize, squeeze=False 51 | ) # Ensure axes is always 2D array 52 | total_data_dict = {} 53 | 54 | for i in range(n_layers): 55 | total_data_dict[f"Layer_{i}"] = {} 56 | for j in range(n_heads): 57 | # Adjust indexing for the flattened layer-head structure 58 | linear_idx = i * n_heads + j 59 | data = total_data[linear_idx, :, :] 60 | 61 | if graph_type == "histogram_graph": 62 | data = data.flatten() 63 | axes[i, j].hist(data, bins=100, log=log_transform, cmap=cmap) 64 | elif graph_type == "imshow_graph": 65 | if fourier_transform_local: 66 | data = np.abs(np.fft.fftshift(np.fft.fft2(data))) 67 | vmin, vmax = ( 68 | (total_min, total_max) 69 | if (global_min_max or global_normalize) 70 | else (data.min(), data.max()) 71 | ) 72 | im = axes[i, j].imshow(data, vmin=vmin, vmax=vmax, cmap=cmap) 73 | axes[i, j].axis("off") 74 | total_data_dict[f"Layer_{i}"][f"Head_{j}"] = data.tolist() 75 | 76 | axes[i, j].set_title(f"Head {j}", fontsize=12, pad=5) if i == 0 else None 77 | if j == 0: 78 | axes[i, j].text( 79 | -0.3, 80 | 0.5, 81 | f"Layer {i}", 82 | fontsize=12, 83 | rotation=90, 84 | ha="center", 85 | va="center", 86 | transform=axes[i, j].transAxes, 87 | ) 88 | 89 | # Add colorbar for imshow_graph 90 | if graph_type == "imshow_graph" and (global_min_max or global_normalize): 91 | cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7]) 92 | fig.colorbar(im, cax=cbar_ax) 93 | cbar_ax.set_title("Attention", size=12) 94 | 95 | plt.subplots_adjust(wspace=0.2, hspace=0.4) 96 | plt.suptitle(f"Attention for Image Idx {idx}", fontsize=20, y=0.93) 97 | plt.show() 98 | -------------------------------------------------------------------------------- 
/src/vit_prisma/visualization/visualize_image.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast 2 | 3 | import torch 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import matplotlib.figure 7 | 8 | 9 | def display_grid_on_image(image: Union[np.ndarray, torch.Tensor], patch_size: int = 32, return_plot: bool = False) -> Optional[matplotlib.figure.Figure]: 10 | """ 11 | Separates an image into a grid of patches and overlays the grid on the image. 12 | 13 | Args: 14 | image (torch.Tensor): The input image, either as a numpy array or a PyTorch tensor. 15 | Dimensions of (H, W, C) if numpy or (C, H, W) if tensor 16 | patch_size (int, optional): The size of each patch in the grid. Default is 32. 17 | return_plot (bool, optional): If True, the function will return the plot figure. If False, it will display the plot. Default is False. 18 | 19 | Returns: 20 | matplotlib.figure.Figure or None: If return_plot is True, returns the matplotlib figure object. Otherwise, displays the image with the grid overlay. 21 | """ 22 | 23 | if isinstance(image, torch.Tensor): 24 | image = image.detach().numpy().transpose(1, 2, 0) 25 | if image.shape[0] != 224: 26 | image = image.transpose(1, 2, 0) 27 | if image.max() <= 1.0: 28 | image = (image * 255).astype(np.uint8) 29 | 30 | num_patches = (image.shape[0] / patch_size) ** 2 31 | grid_size = int(np.sqrt(num_patches)) 32 | 33 | # Calculate patch size 34 | patch_height = image.shape[0] // grid_size 35 | patch_width = image.shape[1] // grid_size 36 | 37 | # Overlay grid 38 | grid_image = np.copy(image) 39 | for i in range(1, grid_size): 40 | # Vertical lines 41 | grid_image[:, patch_width * i, :] = [255, 255, 255] 42 | # Horizontal lines 43 | grid_image[patch_height * i, :, :] = [255, 255, 255] 44 | 45 | fig, ax = plt.subplots(figsize=(4, 4), dpi=100) # Adjust figsize and dpi as needed 46 | 47 | # Place labels 48 | for i in range(grid_size): 49 | for j in range(grid_size): 50 | x_center = (j * patch_width) + (patch_width // 2) 51 | y_center = (i * patch_height) + (patch_height // 2) 52 | # Convert the patch index to 2D coordinates (row, column) 53 | patch_index = i * grid_size + j 54 | row, col = divmod(patch_index, grid_size) 55 | ax.text(x_center, y_center, f"{patch_index+1}", color='red', fontsize=8, ha='center', va='center') 56 | 57 | # Display image with grid and labels 58 | ax.imshow(grid_image) 59 | ax.axis('off') 60 | 61 | if return_plot: 62 | return fig 63 | else: 64 | plt.show() 65 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/README.md: -------------------------------------------------------------------------------- 1 | ## Huggingface versions of VJEPA 2 | 3 | This app contains the code for the Huggingface version of VJEPA, along with conversion scripts to convert our trained checkpoints into the HF format. 
4 |
5 | | Model | VJEPA checkpoint | HF Checkpoint (CW) |
6 | | --------------- | --------------- | --------------- |
7 | | [vjepa_v1_vit_huge](configs.py#L5) | /checkpoint/amaia/video/koustuvs/models/vjepa/v1_public/vith16.pth.tar | /checkpoint/amaia/video/koustuvs/models/vjepa/v1_vit_huge_conv/ |
8 | | [vjepa_v1_vit_huge_384](configs.py#L14) | /checkpoint/amaia/video/koustuvs/models/vjepa/v1_vith16_384/vith16-384.pth.tar | /checkpoint/amaia/video/koustuvs/models/vjepa/v1_vith16_384_hf/ |
9 | | [vjepa_v1.5_vit_g_256](configs.py#L26) | /checkpoint/amaia/video/koustuvs/models/vjepa_v1.5/vm2myt1bin1k-16fpc-vitg16-256px-local-ropev2/latest.pth.tar | /checkpoint/amaia/video/koustuvs/models/vjepa_v1.5/vm2myt1bin1k-16fpc-vitg16-256px-local-ropev2/hf/ |
10 |
11 | The model paths are also listed in [paths_cw.yaml](paths_cw.yaml), which should be updated as newer checkpoints are converted.
12 |
13 | ### Using VJEPA HF in your code
14 |
15 | ```python
16 | from PIL import Image
17 | import requests
18 | from app.vjepa_hf.modeling_vjepa import VJEPAModel, VJEPAImageProcessor
19 | from app.vjepa_hf.configs import CONFIGS
20 | import yaml
21 |
22 | def prepare_img():
23 |     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
24 |     image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
25 |     return image
26 |
27 | model_name = "vjepa_v1.5_vit_g_256"
28 | config = CONFIGS["v1.5"]["vit_g_256"]
29 | model_paths = yaml.safe_load(open('paths_cw.yaml'))
30 | model_path = model_paths[model_name]["loc"]
31 |
32 | image = prepare_img()
33 | processor = VJEPAImageProcessor(crop_size=config.crop_size)
34 | pixel_values = processor(image, return_tensors="pt").pixel_values
35 | pixel_values = pixel_values.repeat(1, 16, 1, 1, 1)  # repeating image 16 times for now
36 | pixel_values = pixel_values.permute(0, 2, 1, 3, 4)  # B x C x T x H x W
37 |
38 | model = VJEPAModel.from_pretrained(model_path)
39 | outputs = model(pixel_values)
40 | ```
41 |
42 | ### Converting VJEPA checkpoints to Huggingface
43 |
44 | Check the [conversion script](convert_vjepa_to_hf.py) for details. For instance, to convert v1.5, do the following:
45 |
46 | ```bash
47 | python -m app.vjepa_hf.convert_vjepa_to_hf \
48 |     --model_name vit_g_256 \
49 |     --ckpt_loc <vjepa checkpoint location> \
50 |     --outp_loc <output location for the HF checkpoint> \
51 |     --model_type v1.5 \
52 |     --convert
53 | ```
54 |
55 | ### Test converted checkpoint
56 |
57 | After every conversion, check that the outputs of the original VJEPA model and the HF model match (see the Python sketch at the end of this README):
58 |
59 | ```bash
60 | python -m app.vjepa_hf.convert_vjepa_to_hf \
61 |     --model_name vit_g_256 \
62 |     --ckpt_loc <vjepa checkpoint location> \
63 |     --outp_loc <output location for the HF checkpoint> \
64 |     --model_type v1.5 \
65 |     --test
66 | ```
67 |
68 | ### Questions?
69 |
70 | Contact Koustuv.
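### Example: sanity-checking a converted checkpoint in Python

A minimal sketch of the kind of numerical comparison the `--test` flag above is meant to perform, included for illustration only. The in-repo import paths, the dummy clip, and the HF-style `last_hidden_state` output field are assumptions rather than the exact code of the conversion script:

```python
import torch
import yaml

from vit_prisma.vjepa_hf.configs.configs import CONFIGS
from vit_prisma.vjepa_hf.modeling_vjepa import VJEPAModel

# Load the converted HF checkpoint listed in paths_cw.yaml.
model_paths = yaml.safe_load(open("paths_cw.yaml"))
hf_model = VJEPAModel.from_pretrained(model_paths["vjepa_v1.5_vit_g_256"]["loc"]).eval()

# Dummy clip shaped B x C x T x H x W, matching the layout used in the usage example above.
config = CONFIGS["v1.5"]["vit_g_256"]
clip = torch.randn(1, 3, 16, config.crop_size, config.crop_size)

with torch.no_grad():
    hf_features = hf_model(clip).last_hidden_state  # HF-style output field (assumed)

# `reference_features` would come from running the original VJEPA encoder on the same clip;
# the conversion is considered correct when the two feature maps agree within tolerance:
# assert torch.allclose(hf_features, reference_features, atol=1e-4)
```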
71 | 72 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/configs/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/configs/configs.py: -------------------------------------------------------------------------------- 1 | from vit_prisma.vjepa_hf.modeling_vjepa import VJEPAConfig 2 | 3 | CONFIGS = { 4 | "v1": { 5 | "vit_h": VJEPAConfig( 6 | model_name="vit_huge", 7 | uniform_power=True, 8 | hidden_size=1280, 9 | num_attention_heads=16, 10 | num_hidden_layers=32, 11 | layer_norm_eps=1e-6, 12 | use_sdpa=True, 13 | ), 14 | "vit_h_384": VJEPAConfig( 15 | model_name="vit_huge", 16 | crop_size=384, 17 | uniform_power=True, 18 | hidden_size=1280, 19 | num_attention_heads=16, 20 | num_hidden_layers=32, 21 | layer_norm_eps=1e-6, 22 | use_sdpa=True, 23 | ), 24 | }, 25 | "v1.5": { 26 | "vit_g_256": VJEPAConfig( 27 | model_name="vit_giant_xformers_rope", 28 | crop_size=256, 29 | uniform_power=True, 30 | hidden_size=1408, 31 | num_attention_heads=22, 32 | num_hidden_layers=40, 33 | use_rope=True, 34 | mlp_ratio=48 / 11, 35 | use_sdpa=True, 36 | ) 37 | }, 38 | } 39 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/configs/sae_jepa_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from vit_prisma.sae.config import VisionModelSAERunnerConfig 4 | from dataclasses import dataclass 5 | 6 | from typing import Literal 7 | 8 | 9 | @dataclass 10 | class JEPABaseConfig(VisionModelSAERunnerConfig): 11 | 12 | 13 | architecture: Literal["standard", "gated", "jumprelu"] = "standard" 14 | expansion_factor: int = 16 15 | 16 | model_name: str = "vjepa_v1_vit_huge" 17 | # model_name = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M" 18 | checkpoint_path: str = "/checkpoint/soniajoseph" # meta cluster 19 | wandb_project: str = "vjepa_l_sae" 20 | 21 | d_in: int = 1280 22 | 23 | min_l0 = 5 24 | min_explained_variance = 0.30 25 | 26 | l1_coefficient: float = 0.7 27 | lr: float = 0.0002 28 | 29 | layer_subtype: str = "hook_mlp_out" 30 | 31 | num_epochs: int = 500 32 | 33 | lr_scheduler_name: str = ( 34 | "cosineannealingwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup 35 | ) 36 | 37 | lr_warm_up_steps: int = 200 38 | wandb_log_frequency: int = 100 39 | 40 | cls_token_only: bool = True # Only do CLS token training 41 | 42 | normalize_activations: str = None # What should this be? 
43 | 44 | feature_sampling_window: int = 1000 45 | dead_feature_window: int = 5000 46 | dead_feature_threshold: float = 1e-08 47 | 48 | device: str = "cuda" 49 | 50 | n_validation_runs: int = 10 # spaced linearly throughout training 51 | 52 | 53 | train_batch_size: int = 4096 54 | 55 | use_ghost_grads: bool = False 56 | 57 | def __post_init__(self): 58 | # initialize parent 59 | super().__post_init__() 60 | 61 | if self.cls_token_only: 62 | self.context_size = 1 63 | 64 | self.hook_point = f"blocks.{self.hook_point_layer}.{self.layer_subtype}" # change hookpoint name here 65 | self.total_training_images = 1_300_000 * self.num_epochs 66 | self.total_training_tokens = self.total_training_images * self.context_size 67 | 68 | 69 | print("Running with CLIP Base Config") 70 | 71 | #/home/mila/s/sonia.joseph/SAE_factory/src/SAE_factory/quick_train/run_sweep_clip.py --hook_point_layer=9 --l1_coefficient=0.8011371355363572 --lr=0.000292136313548604 --train_batch_size=4096 72 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/evals/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/evals/eval_configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/evals/eval_configs/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/evals/eval_configs/imagenet_config.yaml: -------------------------------------------------------------------------------- 1 | tag: imagenet 2 | eval_name: imagenet_classification 3 | nodes: 4 4 | tasks_per_node: 8 5 | cpus_per_task: 10 6 | num_probe_blocks: 1 7 | resume_checkpoint: false 8 | data: 9 | dataset_name: 'ImageNet' 10 | num_classes: 1000 11 | # root_path: None 12 | # image_folder: None 13 | optimization: 14 | num_epochs: 10 15 | batch_size: 32 16 | lr: 0.001 17 | start_lr: 0.001 18 | final_lr: 0.00001 19 | warmup: 0.1 20 | use_bfloat16: false 21 | weight_decay: 0.04 22 | pretrain: 23 | model_name: vit_h 24 | tubelet_size: 2 25 | pos_embed_type: 3d 26 | use_xformers: false 27 | use_sdpa: false 28 | patch_size: 16 29 | folder: "/checkpoint/soniajoseph/v1_vit_huge_conv/" 30 | write_tag: jepa 31 | 32 | 33 | 34 | # args_eval = { 35 | # 'pretrain': { 36 | # 'checkpoint_key': 'target_encoder', 37 | # 'model_name': 'vit_huge', 38 | # 'patch_size': 16, 39 | # 'folder': "/checkpoint/soniajoseph/v1_vit_huge_conv/", 40 | # 'ckp_fname': 'model_checkpoint.pth', 41 | # 'tag': 'my_model_tag', 42 | # 'use_sdpa': True, 43 | # 'use_silu': False, 44 | # 'wide_silu': True, 45 | # 'uniform_power': False, 46 | # 'is_causal': False 47 | # }, 48 | # 'data': { 49 | # 'dataset_name': 'ImageNet', 50 | # 'num_classes': 1000, 51 | # 'root_path': None, 52 | # 'image_folder': None, 53 | # }, 54 | # 'optimization': { 55 | # 'batch_size': 32, 56 | # 'num_epochs': 10, 57 | # 'wd': 1e-6, 58 | # 'start_lr': 1e-4, 59 | # 'lr': 1e-3, 60 | # 'final_lr': 1e-5, 61 | # 'warmup': 0.1, 62 | # 'use_bfloat16': False 63 | # }, 64 | # 'num_probe_blocks': 1, 65 | # 'resume_checkpoint': False, 66 | # 'tag': 'my_model_tag' 67 | # } 
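# Usage sketch (an assumption added for illustration; not part of the original config):
# a launcher typically loads this file with yaml.safe_load and hands the resulting dict
# to the eval scaffold, e.g.
#   args_eval = yaml.safe_load(open("eval_configs/imagenet_config.yaml"))
#   scaffold.main(eval_name=args_eval["eval_name"], args_eval=args_eval)
# scaffold.main() then imports vit_prisma.vjepa_hf.evals.imagenet_classification
# (taken from eval_name above) and forwards args_eval to that module's main().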
-------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/evals/scaffold.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | 4 | from vit_prisma.vjepa_hf.src.utils.logging import get_logger 5 | 6 | logger = get_logger("Eval runner scaffold") 7 | 8 | 9 | def main(eval_name, args_eval, resume_preempt=False): 10 | logger.info(f"Running evaluation: {eval_name}") 11 | import_path = f'vit_prisma.vjepa_hf.evals.{eval_name}' 12 | # if eval_name.startswith("app."): 13 | # import_path = f"{eval_name}.eval" 14 | # else: 15 | # import_path = f"evals.{eval_name}.eval" 16 | return importlib.import_module(import_path).main(args_eval=args_eval, resume_preempt=resume_preempt) 17 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/paths_cw.yaml: -------------------------------------------------------------------------------- 1 | vjepa_v1_vit_huge: 2 | loc: "/checkpoint/soniajoseph/v1_vit_huge_conv/" 3 | vjepa_v1_vit_huge_384: 4 | loc: "/checkpoint/soniajoseph/v1_vith16_384_hf/" 5 | vjepa_v1.5_vit_g_256: 6 | loc: "/checkpoint/soniajoseph/vm2myt1bin1k-16fpc-vitg16-256px-local-ropev2/hf/" 7 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/src/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/datasets/utils/kenburns.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def zoom_pan(img, zoom=1, delta_xy=None): 14 | 15 | _, h, w = img.shape # [C, H, W] 16 | img = np.transpose(img, (1, 2, 0)) # [H, W, C] 17 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # convert torch img to cv2 (RGB->BGR) 18 | 19 | # Pan (translation) 20 | dx, dy = (delta_xy[0] * w, delta_xy[1] * h) 21 | M = np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32) 22 | img = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR) 23 | 24 | # Zoom 25 | cx, cy = (w / 2, h / 2) 26 | M = cv2.getRotationMatrix2D((cx, cy), 0, zoom) 27 | img = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_LINEAR) 28 | 29 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 30 | img = np.transpose(img, (2, 0, 1)) 31 | 32 | return img 33 | 34 | 35 | def apply_kenburns(img, nframe=16, max_zoom=2, end_xy=None, zoom_mode="in", dtype=None): 36 | 37 | # Checks to ensure zoom_mode is passed in as a valid option 38 | assert isinstance(zoom_mode, str), "Flag zoom_mode is not a sring" 39 | zoom_mode = zoom_mode.lower() 40 | assert zoom_mode in ["in", "out", "inout", "outin"], "Unrecognized zoom_mode" 41 | 42 | og_nframe = nframe 43 | if zoom_mode in ["inout", "outin"]: 44 | nframe = nframe // 2 45 | 46 | # Sample start/end coordinates for panning 47 | start_xy = (0.5, 0.5) # start location is img center 48 | if end_xy is None: 49 | end_xy = torch.rand(2).tolist() 50 | end_xy = tuple([v / 4 for v in end_xy]) # only drift from center by 1/4 img width/height 51 | 52 | # -- delta-x,y in each frame (for panning) 53 | dx = (end_xy[0] - start_xy[0]) / nframe 54 | dy = (end_xy[1] - start_xy[1]) / nframe 55 | 56 | # -- delta-zoom in each frame 57 | dz = (max_zoom - 1.0) / nframe 58 | 59 | # Create zoom and pan animation 60 | prev_frame = img 61 | video = [prev_frame] 62 | for _ in range(nframe - 1): 63 | zoom = 1.0 + dz 64 | prev_frame = zoom_pan(prev_frame, zoom=zoom, delta_xy=(dx, dy)) 65 | video += [prev_frame] 66 | 67 | if zoom_mode == "in": 68 | pass # Nothing to do; already zoomed in :) 69 | elif zoom_mode == "out": 70 | video = video[::-1] 71 | elif zoom_mode == "inout": 72 | video = video + video[::-1] 73 | elif zoom_mode == "outin": 74 | video = video[::-1] + video 75 | 76 | # Repeat last frame until video length is the request number of frames, 77 | # there may be a mismatch if 'zoom_mode' is 'inout' or 'outin' and 78 | # nframes is an odd number 79 | if len(video) < og_nframe: 80 | video += [video[-1]] 81 | 82 | if dtype is not None: 83 | video = torch.cat([torch.tensor(f, dtype=dtype).unsqueeze(1) for f in video], dim=1) 84 | 85 | return video 86 | 87 | 88 | def batch_apply_kenburns(batch, nframe=16, max_zoom=2, end_xy=None, zoom_mode="in", dtype=None): 89 | dtype = batch.dtype 90 | device = batch.device 91 | batch = [i.numpy(force=True) for i in batch] # tensor->numpy 92 | batch = [apply_kenburns(i, nframe=nframe, dtype=dtype) for i in batch] 93 | batch = torch.cat([i.unsqueeze(0) for i in batch], dim=0).to(device) 94 | return batch 95 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/datasets/utils/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenizer import SimpleTokenizer, tokenize 2 | 3 | __all__ = ["SimpleTokenizer", "tokenize"] 4 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/datasets/utils/tokenizers/bpe_simple_vocab_16e6.txt.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/src/datasets/utils/tokenizers/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/datasets/utils/utils.py: -------------------------------------------------------------------------------- 1 | from vit_prisma.vjepa_hf.src.utils.cluster import dataset_paths 2 | from vit_prisma.vjepa_hf.src.utils.logging import get_logger 3 | 4 | logger = get_logger("Datasets utils") 5 | 6 | 7 | def get_dataset_paths(datasets: list[str]): 8 | paths = [] 9 | for d in datasets: 10 | try: 11 | path = dataset_paths().get(d) 12 | except Exception: 13 | raise Exception(f"Unknown dataset: {d}") 14 | paths.append(path) 15 | logger.info(f"Datapaths {paths}") 16 | return paths 17 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/datasets/utils/video/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | import cv2 4 | import numpy as np 5 | import PIL 6 | import torch 7 | from torchvision.transforms import functional as tvf 8 | 9 | 10 | def _is_tensor_clip(clip): 11 | return torch.is_tensor(clip) and clip.ndimension() == 4 12 | 13 | 14 | def crop_clip(clip, min_h, min_w, h, w): 15 | if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor): 16 | if clip[0].shape[-1] == 3: 17 | cropped = [img[min_h : min_h + h, min_w : min_w + w, :] for img in clip] 18 | else: 19 | assert clip[0].shape[0] == 3 20 | cropped = [img[:, min_h : min_h + h, min_w : min_w + w] for img in clip] 21 | 22 | elif isinstance(clip[0], PIL.Image.Image): 23 | cropped = [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip] 24 | 25 | else: 26 | raise TypeError( 27 | "Expected numpy.ndarray or PIL.Image or torch.Tensor):" + "but got list of {0}".format(type(clip[0])) 28 | ) 29 | return cropped 30 | 31 | 32 | def resize_clip(clip, size, interpolation="bilinear"): 33 | if isinstance(clip[0], np.ndarray) or isinstance(clip[0], torch.Tensor): 34 | if isinstance(size, numbers.Number): 35 | if clip[0].shape[-1] == 3: 36 | im_h, im_w, im_c = clip[0].shape 37 | else: 38 | assert clip[0].shape[0] == 3 39 | im_c, im_h, im_w = clip[0].shape 40 | # Min spatial dim already matches minimal size 41 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): 42 | return clip 43 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 44 | size = (new_w, new_h) 45 | else: 46 | size = size[0], size[1] 47 | 48 | if isinstance(clip[0], np.ndarray): 49 | if interpolation == "bilinear": 50 | np_inter = cv2.INTER_LINEAR 51 | else: 52 | np_inter = cv2.INTER_NEAREST 53 | scaled = [cv2.resize(img, size, interpolation=np_inter) for img in clip] 54 | else: # isinstance(clip[0], torch.Tensor) 55 | if interpolation == "bilinear": 56 | np_inter = tvf.InterpolationMode.BILINEAR 57 | else: 58 | np_inter = tvf.InterpolationMode.NEAREST 59 | size = (size[1], size[0]) # torchvision transformers expect the size in (h, w) order. 
60 | scaled = [tvf.resize(img, size, interpolation=np_inter) for img in clip] 61 | elif isinstance(clip[0], PIL.Image.Image): 62 | if isinstance(size, numbers.Number): 63 | im_w, im_h = clip[0].size 64 | # Min spatial dim already matches minimal size 65 | if (im_w <= im_h and im_w == size) or (im_h <= im_w and im_h == size): 66 | return clip 67 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 68 | size = (new_w, new_h) 69 | else: 70 | size = size[1], size[0] 71 | if interpolation == "bilinear": 72 | pil_inter = PIL.Image.BILINEAR 73 | else: 74 | pil_inter = PIL.Image.NEAREST 75 | scaled = [img.resize(size, pil_inter) for img in clip] 76 | else: 77 | raise TypeError( 78 | "Expected numpy.ndarray or PIL.Image or torch.Tensor" + "but got list of {0}".format(type(clip[0])) 79 | ) 80 | return scaled 81 | 82 | 83 | def get_resize_sizes(im_h, im_w, size): 84 | if im_w < im_h: 85 | ow = size 86 | oh = int(size * im_h / im_w) 87 | else: 88 | oh = size 89 | ow = int(size * im_w / im_h) 90 | return oh, ow 91 | 92 | 93 | def normalize(clip, mean, std, inplace=False): 94 | if not _is_tensor_clip(clip): 95 | raise TypeError("tensor is not a torch clip.") 96 | 97 | if not inplace: 98 | clip = clip.clone() 99 | 100 | dtype = clip.dtype 101 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 102 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 103 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 104 | 105 | return clip 106 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/masks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/src/masks/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/masks/causal.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from logging import getLogger 9 | from multiprocessing import Value 10 | 11 | import torch 12 | 13 | _GLOBAL_SEED = 0 14 | logger = getLogger() 15 | 16 | 17 | class MaskCollator(object): 18 | 19 | def __init__( 20 | self, 21 | cfgs_mask, 22 | crop_size=(224, 224), 23 | num_frames=16, 24 | patch_size=(16, 16), 25 | tubelet_size=2, 26 | ): 27 | super(MaskCollator, self).__init__() 28 | 29 | self.mask_generators = [] 30 | for m in cfgs_mask: 31 | mask_generator = _MaskGenerator( 32 | crop_size=crop_size, 33 | num_frames=num_frames, 34 | spatial_patch_size=patch_size, 35 | temporal_patch_size=tubelet_size, 36 | ctxt_frames_ratio=m.get("ctxt_frames_ratio"), 37 | ) 38 | self.mask_generators.append(mask_generator) 39 | 40 | def step(self): 41 | for mask_generator in self.mask_generators: 42 | mask_generator.step() 43 | 44 | def __call__(self, batch): 45 | 46 | batch_size = len(batch) 47 | collated_batch = torch.utils.data.default_collate(batch) 48 | 49 | collated_masks_pred, collated_masks_enc = [], [] 50 | for i, mask_generator in enumerate(self.mask_generators): 51 | masks_enc, masks_pred = mask_generator(batch_size) 52 | collated_masks_enc.append(masks_enc) 53 | collated_masks_pred.append(masks_pred) 54 | 55 | return collated_batch, collated_masks_enc, collated_masks_pred 56 | 57 | 58 | class _MaskGenerator(object): 59 | 60 | def __init__( 61 | self, 62 | crop_size=(224, 224), 63 | num_frames=16, 64 | spatial_patch_size=(16, 16), 65 | temporal_patch_size=2, 66 | ctxt_frames_ratio=0.5, 67 | ): 68 | super(_MaskGenerator, self).__init__() 69 | if not isinstance(crop_size, tuple): 70 | crop_size = (crop_size,) * 2 71 | if not isinstance(spatial_patch_size, tuple): 72 | spatial_patch_size = (spatial_patch_size,) * 2 73 | self.crop_size = crop_size 74 | self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)] 75 | self.duration = num_frames // temporal_patch_size 76 | 77 | self.spatial_patch_size = spatial_patch_size 78 | self.temporal_patch_size = temporal_patch_size 79 | 80 | self.nctxt_frames = max(1, int(self.duration * ctxt_frames_ratio)) 81 | self._itr_counter = Value("i", -1) # collator is shared across worker processes 82 | 83 | def step(self): 84 | i = self._itr_counter 85 | with i.get_lock(): 86 | i.value += 1 87 | v = i.value 88 | return v 89 | 90 | def _sample_mask(self, nctxt_frames): 91 | mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) 92 | mask[nctxt_frames:, :, :] = 0 93 | return mask 94 | 95 | def __call__(self, batch_size): 96 | """ 97 | Create encoder and predictor masks when collating imgs into a batch 98 | # 1. sample pred block size using seed 99 | # 2. sample several pred block locations for each image (w/o seed) 100 | # 3. 
return pred masks and complement (enc mask) 101 | """ 102 | seed = self.step() 103 | g = torch.Generator() 104 | g.manual_seed(seed) 105 | 106 | collated_masks_pred, collated_masks_enc = [], [] 107 | min_keep_enc = min_keep_pred = self.duration * self.height * self.width 108 | for _ in range(batch_size): 109 | 110 | empty_context = True 111 | while empty_context: 112 | 113 | mask_e = self._sample_mask(self.nctxt_frames) 114 | mask_e = mask_e.flatten() 115 | mask_p = torch.argwhere(mask_e == 0).squeeze() 116 | mask_e = torch.nonzero(mask_e).squeeze() 117 | 118 | empty_context = len(mask_e) == 0 119 | if not empty_context: 120 | min_keep_pred = min(min_keep_pred, len(mask_p)) 121 | min_keep_enc = min(min_keep_enc, len(mask_e)) 122 | collated_masks_pred.append(mask_p) 123 | collated_masks_enc.append(mask_e) 124 | 125 | collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] 126 | collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) 127 | # -- 128 | collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] 129 | collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) 130 | 131 | return collated_masks_enc, collated_masks_pred 132 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/masks/default.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from logging import getLogger 9 | 10 | import torch 11 | 12 | _GLOBAL_SEED = 0 13 | logger = getLogger() 14 | 15 | 16 | class DefaultCollator(object): 17 | 18 | def __call__(self, batch): 19 | collated_batch = torch.utils.data.default_collate(batch) 20 | return collated_batch, None, None 21 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/masks/random.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from logging import getLogger 9 | from multiprocessing import Value 10 | 11 | import torch 12 | 13 | _GLOBAL_SEED = 0 14 | logger = getLogger() 15 | 16 | 17 | class MaskCollator(object): 18 | 19 | def __init__( 20 | self, 21 | ratio=(0.4, 0.6), 22 | input_size=(224, 224), 23 | patch_size=16, 24 | ): 25 | super(MaskCollator, self).__init__() 26 | if not isinstance(input_size, tuple): 27 | input_size = (input_size,) * 2 28 | self.patch_size = patch_size 29 | self.height, self.width = input_size[0] // patch_size, input_size[1] // patch_size 30 | self.num_patches = self.height * self.width 31 | self.ratio = ratio 32 | self._itr_counter = Value("i", -1) # collator is shared across worker processes 33 | 34 | def step(self): 35 | i = self._itr_counter 36 | with i.get_lock(): 37 | i.value += 1 38 | v = i.value 39 | return v 40 | 41 | def __call__(self, batch): 42 | """Create encoder and predictor masks when collating imgs into a batch""" 43 | B = len(batch) 44 | 45 | collated_batch = torch.utils.data.default_collate(batch) 46 | 47 | seed = self.step() 48 | g = torch.Generator() 49 | g.manual_seed(seed) 50 | ratio = self.ratio 51 | ratio = ratio[0] + torch.rand(1, generator=g).item() * (ratio[1] - ratio[0]) 52 | num_keep = int(self.num_patches * (1.0 - ratio)) 53 | 54 | collated_masks_pred, collated_masks_enc = [], [] 55 | for _ in range(B): 56 | m = torch.randperm(self.num_patches) 57 | collated_masks_enc.append([m[:num_keep]]) 58 | collated_masks_pred.append([m[num_keep:]]) 59 | 60 | collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) 61 | collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) 62 | 63 | return collated_batch, collated_masks_enc, collated_masks_pred 64 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/masks/random_tube.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | from logging import getLogger 9 | from multiprocessing import Value 10 | 11 | import numpy as np 12 | import torch 13 | 14 | _GLOBAL_SEED = 0 15 | logger = getLogger() 16 | 17 | 18 | class MaskCollator(object): 19 | 20 | def __init__( 21 | self, 22 | cfgs_mask, 23 | crop_size=(224, 224), 24 | num_frames=16, 25 | patch_size=(16, 16), 26 | tubelet_size=2, 27 | ): 28 | super(MaskCollator, self).__init__() 29 | 30 | self.mask_generators = [] 31 | for m in cfgs_mask: 32 | mask_generator = _MaskGenerator( 33 | crop_size=crop_size, 34 | num_frames=num_frames, 35 | spatial_patch_size=patch_size, 36 | temporal_patch_size=tubelet_size, 37 | ratio=m.get("ratio"), 38 | ) 39 | self.mask_generators.append(mask_generator) 40 | 41 | def step(self): 42 | for mask_generator in self.mask_generators: 43 | mask_generator.step() 44 | 45 | def __call__(self, batch): 46 | 47 | batch_size = len(batch) 48 | collated_batch = torch.utils.data.default_collate(batch) 49 | 50 | collated_masks_pred, collated_masks_enc = [], [] 51 | for mask_generator in self.mask_generators: 52 | masks_enc, masks_pred = mask_generator(batch_size) 53 | collated_masks_enc.append(masks_enc) 54 | collated_masks_pred.append(masks_pred) 55 | 56 | return collated_batch, collated_masks_enc, collated_masks_pred 57 | 58 | 59 | class _MaskGenerator(object): 60 | 61 | def __init__( 62 | self, 63 | crop_size=(224, 224), 64 | num_frames=16, 65 | spatial_patch_size=(16, 16), 66 | temporal_patch_size=2, 67 | ratio=0.9, 68 | ): 69 | super(_MaskGenerator, self).__init__() 70 | if not isinstance(crop_size, tuple): 71 | crop_size = (crop_size,) * 2 72 | if not isinstance(spatial_patch_size, tuple): 73 | spatial_patch_size = (spatial_patch_size,) * 2 74 | self.crop_size = crop_size 75 | self.height, self.width = [crop_size[i] // spatial_patch_size[i] for i in (0, 1)] 76 | self.duration = num_frames // temporal_patch_size 77 | 78 | self.spatial_patch_size = spatial_patch_size 79 | self.temporal_patch_size = temporal_patch_size 80 | self.num_patches_spatial = self.height * self.width 81 | 82 | self.ratio = ratio 83 | 84 | self.num_keep_spatial = int(self.num_patches_spatial * (1.0 - self.ratio)) 85 | self.num_keep = self.num_keep_spatial * self.duration 86 | 87 | self._itr_counter = Value("i", -1) # collator is shared across worker processes 88 | 89 | def step(self): 90 | i = self._itr_counter 91 | with i.get_lock(): 92 | i.value += 1 93 | v = i.value 94 | return v 95 | 96 | def __call__(self, batch_size): 97 | def sample_mask(): 98 | mask = np.hstack( 99 | [ 100 | np.zeros(self.num_patches_spatial - self.num_keep_spatial), 101 | np.ones(self.num_keep_spatial), 102 | ] 103 | ) 104 | np.random.shuffle(mask) 105 | mask = torch.tensor(np.tile(mask, (self.duration, 1))) 106 | mask = mask.flatten() 107 | mask_p = torch.argwhere(mask == 0).squeeze() 108 | mask_e = torch.nonzero(mask).squeeze() 109 | return mask_e, mask_p 110 | 111 | collated_masks_pred, collated_masks_enc = [], [] 112 | for _ in range(batch_size): 113 | mask_e, mask_p = sample_mask() 114 | collated_masks_enc.append(mask_e) 115 | collated_masks_pred.append(mask_p) 116 | 117 | collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) 118 | collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) 119 | 120 | return collated_masks_enc, collated_masks_pred 121 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/masks/utils.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | 10 | 11 | def apply_masks(x, masks, concat=True): 12 | """ 13 | :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] 14 | :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep 15 | """ 16 | all_x = [] 17 | for m in masks: 18 | mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) 19 | all_x += [torch.gather(x, dim=1, index=mask_keep)] 20 | if not concat: 21 | return all_x 22 | 23 | return torch.cat(all_x, dim=0) 24 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/src/models/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/attentive_pooler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from vit_prisma.vjepa_hf.src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock 15 | from vit_prisma.vjepa_hf.src.utils.tensors import trunc_normal_ 16 | 17 | 18 | class AttentivePooler(nn.Module): 19 | """Attentive Pooler""" 20 | 21 | def __init__( 22 | self, 23 | num_queries=1, 24 | embed_dim=768, 25 | num_heads=12, 26 | mlp_ratio=4.0, 27 | depth=1, 28 | norm_layer=nn.LayerNorm, 29 | init_std=0.02, 30 | qkv_bias=True, 31 | complete_block=True, 32 | use_activation_checkpointing=False, 33 | ): 34 | super().__init__() 35 | self.use_activation_checkpointing = use_activation_checkpointing 36 | self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) 37 | 38 | self.complete_block = complete_block 39 | if complete_block: 40 | self.cross_attention_block = CrossAttentionBlock( 41 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer 42 | ) 43 | else: 44 | self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias) 45 | 46 | self.blocks = None 47 | if depth > 1: 48 | self.blocks = nn.ModuleList( 49 | [ 50 | Block( 51 | dim=embed_dim, 52 | num_heads=num_heads, 53 | mlp_ratio=mlp_ratio, 54 | qkv_bias=qkv_bias, 55 | qk_scale=False, 56 | norm_layer=norm_layer, 57 | ) 58 | for i in range(depth - 1) 59 | ] 60 | ) 61 | 62 | self.init_std = init_std 63 | trunc_normal_(self.query_tokens, std=self.init_std) 64 | self.apply(self._init_weights) 65 | self._rescale_blocks() 66 | 67 | def _rescale_blocks(self): 68 | def rescale(param, layer_id): 69 | param.div_(math.sqrt(2.0 * layer_id)) 70 | 71 | layer_id = 0 72 | if self.blocks is not None: 73 | for layer_id, layer in enumerate(self.blocks): 74 | rescale(layer.attn.proj.weight.data, layer_id + 1) 75 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 76 | 77 | if 
self.complete_block: 78 | # rescale(self.cross_attention_block.xattn.proj.weight.data, layer_id + 1) 79 | rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1) 80 | # else: 81 | # rescale(self.cross_attention_block.proj.weight.data, layer_id + 1) 82 | 83 | def _init_weights(self, m): 84 | if isinstance(m, nn.Linear): 85 | trunc_normal_(m.weight, std=self.init_std) 86 | if isinstance(m, nn.Linear) and m.bias is not None: 87 | nn.init.constant_(m.bias, 0) 88 | elif isinstance(m, nn.LayerNorm): 89 | nn.init.constant_(m.bias, 0) 90 | nn.init.constant_(m.weight, 1.0) 91 | elif isinstance(m, nn.Conv2d): 92 | trunc_normal_(m.weight, std=self.init_std) 93 | if m.bias is not None: 94 | nn.init.constant_(m.bias, 0) 95 | 96 | def forward(self, x): 97 | if self.blocks is not None: 98 | for blk in self.blocks: 99 | if self.use_activation_checkpointing: 100 | x = torch.utils.checkpoint.checkpoint(blk, x, False, None, use_reentrant=False) 101 | else: 102 | x = blk(x) 103 | q = self.query_tokens.repeat(len(x), 1, 1) 104 | q = self.cross_attention_block(q, x) 105 | return q 106 | 107 | 108 | class AttentiveClassifier(nn.Module): 109 | """Attentive Classifier""" 110 | 111 | def __init__( 112 | self, 113 | embed_dim=768, 114 | num_heads=12, 115 | mlp_ratio=4.0, 116 | depth=1, 117 | norm_layer=nn.LayerNorm, 118 | init_std=0.02, 119 | qkv_bias=True, 120 | num_classes=1000, 121 | complete_block=True, 122 | use_activation_checkpointing=False, 123 | ): 124 | super().__init__() 125 | self.pooler = AttentivePooler( 126 | num_queries=1, 127 | embed_dim=embed_dim, 128 | num_heads=num_heads, 129 | mlp_ratio=mlp_ratio, 130 | depth=depth, 131 | norm_layer=norm_layer, 132 | init_std=init_std, 133 | qkv_bias=qkv_bias, 134 | complete_block=complete_block, 135 | use_activation_checkpointing=use_activation_checkpointing, 136 | ) 137 | self.linear = nn.Linear(embed_dim, num_classes, bias=True) 138 | 139 | def forward(self, x): 140 | x = self.pooler(x).squeeze(1) 141 | x = self.linear(x) 142 | return x 143 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger() 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | # logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 
35 | norm_layer: Normalization layer. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | # assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | # assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/baselines/dinov2_layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
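# Illustrative usage sketch for the SwiGLU feed-forward blocks defined below
# (the 768/2048 sizes are assumed example values, not repository settings):
#
#   import torch
#   ffn = SwiGLUFFN(in_features=768, hidden_features=2048)
#   x = torch.randn(2, 197, 768)         # (batch, tokens, dim)
#   y = ffn(x)                           # w3(silu(x1) * x2), where x1, x2 = chunks of w12(x); shape (2, 197, 768)
#
# SwiGLUFFNFused (further below) additionally rounds the hidden width to a
# multiple of 8 and uses the fused xFormers SwiGLU kernel when it is installed.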
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/utils/multimask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch.nn as nn 9 | 10 | 11 | class MultiMaskWrapper(nn.Module): 12 | 13 | def __init__(self, backbone): 14 | super().__init__() 15 | self.backbone = backbone 16 | 17 | def forward(self, x, masks=None, targets=None): 18 | if masks is None: 19 | return self.backbone(x) 20 | 21 | if (masks is not None) and not isinstance(masks, list): 22 | masks = [masks] 23 | outs = [] 24 | for m in masks: 25 | if targets is None: 26 | outs += [self.backbone(x, masks=m)] 27 | else: 28 | outs += [self.backbone(x, masks=m, targets=targets)] 29 | return outs 30 | 31 | 32 | class PredictorMultiMaskWrapper(nn.Module): 33 | 34 | def __init__(self, backbone): 35 | super().__init__() 36 | self.backbone = backbone 37 | 38 | def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): 39 | if type(ctxt) is not list: 40 | ctxt = [ctxt] 41 | if type(tgt) is not list: 42 | tgt = [tgt] 43 | if type(masks_ctxt) is not list: 44 | masks_ctxt = [masks_ctxt] 45 | if type(masks_tgt) is not list: 46 | masks_tgt = [masks_tgt] 47 | 48 | outs = [] 49 | for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): 50 | outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] 51 | return outs 52 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/utils/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from einops import rearrange 9 | 10 | import torch.nn as nn 11 | 12 | 13 | class AudioPatchEmbed(nn.Module): 14 | """ 15 | Audio to Patch Embedding 16 | """ 17 | def __init__( 18 | self, 19 | freq_bands=128, 20 | tubelet_size=2, 21 | embed_dim=768 22 | ): 23 | super().__init__() 24 | self.freq_bands = freq_bands 25 | self.tubelet_size = tubelet_size 26 | self.proj = nn.Conv2d(1, embed_dim, 27 | kernel_size=(freq_bands, tubelet_size), 28 | stride=(freq_bands, tubelet_size)) 29 | 30 | def forward(self, x): 31 | x = rearrange(x, 'b t c f -> b c f t') 32 | x = self.proj(x).flatten(2).transpose(1, 2) 33 | return x 34 | 35 | 36 | class PatchEmbed(nn.Module): 37 | """ 38 | Image to Patch Embedding 39 | """ 40 | def __init__( 41 | self, 42 | patch_size=16, 43 | in_chans=3, 44 | embed_dim=768 45 | ): 46 | super().__init__() 47 | self.patch_size = patch_size 48 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 49 | 50 | def forward(self, x): 51 | B, C, H, W = x.shape 52 | x = self.proj(x).flatten(2).transpose(1, 2) 53 | return x 54 | 55 | 56 | class PatchEmbed3D(nn.Module): 57 | """ 58 | Image to Patch Embedding 59 | """ 60 | 61 | def __init__( 62 | self, 63 | patch_size=16, 64 | tubelet_size=2, 65 | in_chans=3, 66 | embed_dim=768, 67 | ): 68 | super().__init__() 69 | self.patch_size = patch_size 70 | self.tubelet_size = tubelet_size 71 | 72 | self.proj = nn.Conv3d( 73 | in_channels=in_chans, 74 | out_channels=embed_dim, 75 | kernel_size=(tubelet_size, patch_size, patch_size), 76 | stride=(tubelet_size, patch_size, patch_size), 77 | ) 78 | 79 | def forward(self, x, **kwargs): 80 | B, C, T, H, W = x.shape 81 | x = self.proj(x).flatten(2).transpose(1, 2) 82 | return x 83 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/models/utils/pos_embs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
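# Illustrative usage sketch for the sin/cos position-embedding helpers defined
# below (a 14x14 patch grid and embed_dim=768 are assumed example values):
#
#   pos2d = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
#   # pos2d.shape == (196, 768); cls_token=True prepends a zero row
#   pos3d = get_3d_sincos_pos_embed(embed_dim=768, grid_size=14, grid_depth=8)
#   # pos3d.shape == (8 * 14 * 14, 768) == (1568, 768)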
6 | # 7 | 8 | import numpy as np 9 | 10 | 11 | def get_3d_sincos_pos_embed( 12 | embed_dim, 13 | grid_size, 14 | grid_depth, 15 | cls_token=False, 16 | uniform_power=False 17 | ): 18 | """ 19 | grid_size: int of the grid height and width 20 | grid_depth: int of the grid depth 21 | returns: 22 | pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) 23 | or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) 24 | """ 25 | grid_d = np.arange(grid_depth, dtype=float) 26 | grid_h = np.arange(grid_size, dtype=float) 27 | grid_w = np.arange(grid_size, dtype=float) 28 | grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] 29 | 30 | if not uniform_power: 31 | h_embed_dim = embed_dim // 4 32 | w_embed_dim = embed_dim // 4 33 | d_embed_dim = embed_dim // 2 34 | else: 35 | h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) 36 | 37 | emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) 38 | emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) 39 | emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) 40 | pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) 41 | pos_embed = pos_embed[:, :embed_dim] 42 | if cls_token: 43 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 44 | return pos_embed 45 | 46 | 47 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 48 | """ 49 | grid_size: int of the grid height and width 50 | returns: 51 | pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) 52 | or [1+grid_size*grid_size, embed_dim] (w/ cls_token) 53 | """ 54 | grid_h = np.arange(grid_size, dtype=float) 55 | grid_w = np.arange(grid_size, dtype=float) 56 | grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] 57 | 58 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) 59 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) 60 | pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 61 | if cls_token: 62 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 63 | return pos_embed 64 | 65 | 66 | def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 67 | """ 68 | embed_dim: output dimension for each position 69 | grid_size: int of the grid length 70 | returns: 71 | pos_embed: [grid_size, embed_dim] (w/o cls_token) 72 | or [1+grid_size, embed_dim] (w/ cls_token) 73 | """ 74 | grid = np.arange(grid_size, dtype=float) 75 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) 76 | if cls_token: 77 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 78 | return pos_embed 79 | 80 | 81 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 82 | """ 83 | embed_dim: output dimension for each position 84 | pos: a list of positions to be encoded: size (M,) 85 | returns: (M, D) 86 | """ 87 | assert embed_dim % 2 == 0 88 | omega = np.arange(embed_dim // 2, dtype=float) 89 | omega /= embed_dim / 2. 90 | omega = 1. 
/ 10000**omega # (D/2,) 91 | 92 | pos = pos.reshape(-1) # (M,) 93 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 94 | 95 | emb_sin = np.sin(out) # (M, D/2) 96 | emb_cos = np.cos(out) # (M, D/2) 97 | 98 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 99 | return emb 100 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/src/vit_prisma/vjepa_hf/src/utils/__init__.py -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/cluster_processor/__init__.py: -------------------------------------------------------------------------------- 1 | from ._cluster_processor import ClusterProcessor 2 | from ._local_file_dataset import LocalFileDataset 3 | 4 | try: 5 | from ._rsc_cluster_processor import BlobstoreIterator, RSCClusterProcessor 6 | except ImportError: 7 | pass 8 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/cluster_processor/_local_file_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from vit_prisma.vjepa_hf.src.utils.logging import get_logger 5 | 6 | 7 | class LocalFileDataset: 8 | """An iterator over local video files, input as a list of file paths.""" 9 | 10 | def __init__( 11 | self, 12 | data_path: str, 13 | offset: int = 0, 14 | limit: int = -1, 15 | ext: str = ".mp4", 16 | id_key: str = "video_id", 17 | data_key: str = "video", 18 | ): 19 | self.logger = get_logger(self.__class__.__name__) 20 | self.data_path = data_path 21 | if offset < 0: 22 | raise ValueError("offset cannot be negative") 23 | self.offset = offset 24 | self.limit = limit 25 | self.flist = sorted(Path(data_path).rglob(f"*{ext}")) 26 | self.id_key = id_key 27 | self.data_key = data_key 28 | 29 | if self.limit > 0: 30 | self.flist = self.flist[self.offset : self.offset + self.limit] 31 | else: 32 | self.flist = self.flist[self.offset :] 33 | 34 | def __len__(self): 35 | return len(self.flist) 36 | 37 | def __getitem__(self, idx): 38 | local_file_path = str(self.flist[idx].resolve()) 39 | video_id = os.path.splitext(os.path.basename(local_file_path))[0] 40 | with open(local_file_path, "rb") as f: 41 | data = f.read() 42 | # Iterator format should be the same as for AIRStore 43 | return {self.id_key: video_id, self.data_key: data} 44 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
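# Illustrative usage sketch for the helpers defined below (the port and tensors
# here are example assumptions; outside SLURM the call falls back to
# world_size=1, rank=0):
#
#   world_size, rank = init_distributed(port=37129)
#   # AllReduce averages a tensor across ranks in the forward pass and passes
#   # gradients straight through in the backward pass:
#   loss = AllReduce.apply(loss)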
6 | # 7 | 8 | import os 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | from vit_prisma.vjepa_hf.src.utils.logging import get_logger 15 | 16 | logger = get_logger() 17 | 18 | 19 | def init_distributed(port=37129, rank_and_world_size=(None, None)): 20 | 21 | if dist.is_available() and dist.is_initialized(): 22 | return dist.get_world_size(), dist.get_rank() 23 | 24 | rank, world_size = rank_and_world_size 25 | os.environ["MASTER_ADDR"] = "localhost" 26 | 27 | if (rank is None) or (world_size is None): 28 | try: 29 | world_size = int(os.environ["SLURM_NTASKS"]) 30 | rank = int(os.environ["SLURM_PROCID"]) 31 | os.environ["MASTER_ADDR"] = os.environ["HOSTNAME"] 32 | except Exception: 33 | logger.info("SLURM vars not set (distributed training not available)") 34 | world_size, rank = 1, 0 35 | return world_size, rank 36 | 37 | try: 38 | os.environ["MASTER_PORT"] = str(port) 39 | torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank) 40 | except Exception as e: 41 | world_size, rank = 1, 0 42 | logger.info(f"Rank: {rank}. Distributed training not available {e}") 43 | 44 | if "SLURM_JOB_ID" in os.environ: 45 | # Use the slurm_tmpdir (if it exists) instead of /tmp 46 | tmpdir = Path(f"/scratch/slurm_tmpdir/{os.environ['SLURM_JOB_ID']}") 47 | if tmpdir.exists(): 48 | os.environ["TMPDIR"] = str(tmpdir) 49 | 50 | return world_size, rank 51 | 52 | 53 | class AllGather(torch.autograd.Function): 54 | 55 | @staticmethod 56 | def forward(ctx, x): 57 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 58 | x = x.contiguous() 59 | outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 60 | dist.all_gather(outputs, x) 61 | return torch.cat(outputs, 0) 62 | return x 63 | 64 | @staticmethod 65 | def backward(ctx, grads): 66 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 67 | s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() 68 | e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) 69 | grads = grads.contiguous() 70 | dist.all_reduce(grads) 71 | return grads[s:e] 72 | return grads 73 | 74 | 75 | class AllReduceSum(torch.autograd.Function): 76 | 77 | @staticmethod 78 | def forward(ctx, x): 79 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 80 | x = x.contiguous() 81 | dist.all_reduce(x) 82 | return x 83 | 84 | @staticmethod 85 | def backward(ctx, grads): 86 | return grads 87 | 88 | 89 | class AllReduce(torch.autograd.Function): 90 | 91 | @staticmethod 92 | def forward(ctx, x): 93 | if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1): 94 | x = x.contiguous() / dist.get_world_size() 95 | dist.all_reduce(x) 96 | return x 97 | 98 | @staticmethod 99 | def backward(ctx, grads): 100 | return grads 101 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vit_prisma.vjepa_hf.src.utils.tensors import trunc_normal_ 4 | 5 | 6 | def count_parameters(model): 7 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 8 | 9 | 10 | def init_weights(m): 11 | if isinstance(m, torch.nn.Linear): 12 | trunc_normal_(m.weight, std=0.02) 13 | if m.bias is not None: 14 | torch.nn.init.constant_(m.bias, 0) 15 | elif isinstance(m, torch.nn.LayerNorm): 16 | 
torch.nn.init.constant_(m.bias, 0) 17 | torch.nn.init.constant_(m.weight, 1.0) 18 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/profiler.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from vit_prisma.vjepa_hf.src.datasets.utils.dataloader import get_worker_info 6 | from vit_prisma.vjepa_hf.src.utils.logging import get_logger 7 | 8 | logger = get_logger("Profiler utils") 9 | 10 | 11 | def get_profiler_trace_handler(output_dir, rank, print_number_rows=30): 12 | 13 | def trace_handler(prof): 14 | if print_number_rows > 0: 15 | print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_time_total", row_limit=print_number_rows)) 16 | print(prof.key_averages(group_by_stack_n=5).table(sort_by="cuda_time_total", row_limit=print_number_rows)) 17 | print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cpu_time_total", row_limit=print_number_rows)) 18 | print(prof.key_averages(group_by_stack_n=5).table(sort_by="cpu_time_total", row_limit=print_number_rows)) 19 | 20 | if output_dir is None: 21 | return 22 | 23 | # Save the trace to a file 24 | profiler_output_dir = os.path.join(output_dir, "profiler_trace") 25 | os.makedirs(profiler_output_dir, exist_ok=True) 26 | 27 | _, worker_id = get_worker_info() 28 | output_fpath = os.path.join(profiler_output_dir, f"rank_{rank}_worker{worker_id}.json") 29 | logger.info(f"Saving DATALOADER trace for rank {rank} to {output_fpath}") 30 | prof.export_chrome_trace(output_fpath) 31 | logger.info(f"Successfully saved DATALOADER profiler trace to {output_fpath}") 32 | 33 | return trace_handler 34 | 35 | 36 | def get_profiler_scheduler(scheduler_conf): 37 | """ 38 | Creates a profiler schedule based on the config. 39 | 40 | For details of the config, please refer to: 41 | https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html#using-profiler-to-analyze-long-running-jobs 42 | """ 43 | if not scheduler_conf: 44 | return None 45 | 46 | logger.info(f"Profiler scheduler config: {scheduler_conf}") 47 | 48 | skip_first = scheduler_conf.get("skip_first", 4) 49 | wait = scheduler_conf.get("wait", 1) 50 | warmup = scheduler_conf.get("warmup", 1) 51 | active = scheduler_conf.get("active", 2) 52 | repeat = scheduler_conf.get("repeat", 2) 53 | return torch.profiler.schedule(skip_first=skip_first, warmup=warmup, wait=wait, active=active, repeat=repeat) 54 | 55 | 56 | def create_profiler(profiler_conf, folder, rank): 57 | if profiler_conf is None: 58 | return None 59 | 60 | profiler_activities = [] 61 | profiler_scheduler_conf = profiler_conf.get("scheduler", {}) 62 | 63 | if profiler_conf.get("cpu", False): 64 | profiler_activities.append(torch.profiler.ProfilerActivity.CPU) 65 | if profiler_conf.get("cuda", False): 66 | profiler_activities.append(torch.profiler.ProfilerActivity.CUDA) 67 | 68 | if folder: 69 | job_id = os.environ.get("SLURM_JOB_ID", None) 70 | folder = os.path.join(folder, f"dataloader_job_{job_id}") 71 | 72 | profiler = torch.profiler.profile( 73 | activities=profiler_activities, 74 | schedule=get_profiler_scheduler(profiler_scheduler_conf), 75 | on_trace_ready=get_profiler_trace_handler(folder, rank), 76 | record_shapes=profiler_conf.get("record_shapes", False), 77 | profile_memory=profiler_conf.get("profile_memory", True), 78 | use_cuda=profiler_conf.get("use_cuda", True), 79 | with_stack=profiler_conf.get("with_stack", True), 80 | ) 81 | return profiler 82 | -------------------------------------------------------------------------------- 
/src/vit_prisma/vjepa_hf/src/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import math 9 | 10 | 11 | class WarmupCosineSchedule(object): 12 | 13 | def __init__(self, optimizer, warmup_steps, start_lr, ref_lr, T_max, last_epoch=-1, final_lr=0.0): 14 | self.optimizer = optimizer 15 | self.start_lr = start_lr 16 | self.ref_lr = ref_lr 17 | self.final_lr = final_lr 18 | self.warmup_steps = warmup_steps 19 | self.T_max = T_max - warmup_steps 20 | self._step = 0.0 21 | 22 | def step(self): 23 | self._step += 1 24 | if self._step < self.warmup_steps: 25 | progress = float(self._step) / float(max(1, self.warmup_steps)) 26 | new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) 27 | else: 28 | # -- progress after warmup 29 | progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) 30 | new_lr = max( 31 | self.final_lr, 32 | self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)), 33 | ) 34 | 35 | for group in self.optimizer.param_groups: 36 | group["lr"] = new_lr 37 | 38 | return new_lr 39 | 40 | 41 | class CosineWDSchedule(object): 42 | 43 | def __init__(self, optimizer, ref_wd, T_max, final_wd=0.0): 44 | self.optimizer = optimizer 45 | self.ref_wd = ref_wd 46 | self.final_wd = final_wd 47 | self.T_max = T_max 48 | self._step = 0.0 49 | 50 | def step(self): 51 | self._step += 1 52 | progress = self._step / self.T_max 53 | new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1.0 + math.cos(math.pi * progress)) 54 | 55 | if self.final_wd <= self.ref_wd: 56 | new_wd = max(self.final_wd, new_wd) 57 | else: 58 | new_wd = min(self.final_wd, new_wd) 59 | 60 | for group in self.optimizer.param_groups: 61 | if ("WD_exclude" not in group) or not group["WD_exclude"]: 62 | group["weight_decay"] = new_wd 63 | return new_wd 64 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/temp_ssd_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | 4 | from vit_prisma.vjepa_hf.src.utils.logging import get_logger 5 | 6 | LOGGER = get_logger("temp_ssd_dir") 7 | 8 | 9 | @lru_cache 10 | def temp_ssd_dir(): 11 | job_id = os.environ.get("SLURM_JOB_ID", None) 12 | if job_id: 13 | slurm_ssd_temp_path = os.path.join("/scratch/slurm_tmpdir", job_id) 14 | if os.path.exists(slurm_ssd_temp_path): 15 | LOGGER.info(f"Using existing slurm temp dir: {slurm_ssd_temp_path}") 16 | return slurm_ssd_temp_path 17 | LOGGER.info("Using the default OS path for temp dir") 18 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/src/utils/tensors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
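# Illustrative usage sketch for the helpers defined below (the shapes are
# example assumptions: batch of 4, 196 patch tokens, feature dim 768):
#
#   import torch
#   w = torch.empty(768, 768)
#   trunc_normal_(w, std=0.02)                          # truncated normal, clamped to [-2, 2]
#   x = torch.randn(4, 196, 768)                        # [B, N, D] patch tokens
#   keep = torch.arange(98).unsqueeze(0).repeat(4, 1)   # per-sample indices of patches to keep
#   out = apply_masks(x, [keep])                        # -> (4, 98, 768)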
6 | # 7 | 8 | import math 9 | from logging import getLogger 10 | 11 | import torch 12 | 13 | logger = getLogger() 14 | 15 | 16 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 17 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 18 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 19 | def norm_cdf(x): 20 | # Computes standard normal cumulative distribution function 21 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 22 | 23 | with torch.no_grad(): 24 | # Values are generated by using a truncated uniform distribution and 25 | # then using the inverse CDF for the normal distribution. 26 | # Get upper and lower cdf values 27 | lower = norm_cdf((a - mean) / std) 28 | upper = norm_cdf((b - mean) / std) 29 | 30 | # Uniformly fill tensor with values from [lower, upper], then translate to 31 | # [2*lower-1, 2*upper-1]. 32 | tensor.uniform_(2 * lower - 1, 2 * upper - 1) 33 | 34 | # Use inverse cdf transform for normal distribution to get truncated 35 | # standard normal 36 | tensor.erfinv_() 37 | 38 | # Transform to proper mean, std 39 | tensor.mul_(std * math.sqrt(2.0)) 40 | tensor.add_(mean) 41 | 42 | # Clamp to ensure it's in the proper range 43 | tensor.clamp_(min=a, max=b) 44 | return tensor 45 | 46 | 47 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 48 | # type: (Tensor, float, float, float, float) -> Tensor 49 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 50 | 51 | 52 | def apply_masks(x, masks): 53 | """ 54 | :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] 55 | :param masks: list of tensors containing indices of patches [0,N) to keep 56 | """ 57 | all_x = [] 58 | for m in masks: 59 | mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) 60 | all_x += [torch.gather(x, dim=1, index=mask_keep)] 61 | return torch.cat(all_x, dim=0) 62 | 63 | 64 | def repeat_interleave_batch(x, B, repeat): 65 | N = len(x) // B 66 | x = torch.cat([torch.cat([x[i * B : (i + 1) * B] for _ in range(repeat)], dim=0) for i in range(N)], dim=0) 67 | return x 68 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/test_vjepa.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import requests 3 | from vit_prisma.vjepa_hf.modeling_vjepa import VJEPAModel, VJEPAImageProcessor 4 | from vit_prisma.vjepa_hf.configs import CONFIGS 5 | import yaml 6 | 7 | import torch 8 | 9 | from vit_prisma.models.base_vit import HookedViT 10 | 11 | 12 | def prepare_img(): 13 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 14 | image = Image.open(requests.get(url, stream=True).raw).convert("RGB") 15 | return image 16 | 17 | 18 | def check_matching(outputs, outputs_hooked, tolerance=1e-3): 19 | # Compare the outputs 20 | diff = torch.abs(outputs_hooked - outputs) 21 | max_diff = torch.max(diff) 22 | if max_diff < tolerance: 23 | print("Outputs match within tolerance") 24 | else: 25 | print(f"Outputs do not match within tolerance. 
Max difference: {max_diff.item()}") 26 | # Alternatively, you can use torch.allclose() function 27 | if torch.allclose(outputs_hooked, outputs, atol=tolerance): 28 | print("Outputs match within tolerance") 29 | else: 30 | print("Outputs do not match within tolerance") 31 | # You can also calculate the mean squared error (MSE) or mean absolute error (MAE) 32 | mse = torch.mean((outputs_hooked - outputs) ** 2) 33 | mae = torch.mean(torch.abs(outputs_hooked - outputs)) 34 | print(f"MSE: {mse.item():.6f}") 35 | print(f"MAE: {mae.item():.6f}") 36 | 37 | diff = torch.abs(outputs_hooked - outputs) 38 | max_diff = torch.max(diff) 39 | # Get the indices of the elements that don't match 40 | print("number of total indices", len(diff.flatten())) 41 | non_match_indices = torch.nonzero(diff > tolerance) 42 | if non_match_indices.numel() > 0: 43 | print(f"Number of non-matching indices: {non_match_indices.shape[0]}") 44 | print("Non-matching indices:") 45 | # for i in range(non_match_indices.shape[0]): 46 | # batch_idx, seq_idx, embed_idx = non_match_indices[i] 47 | # print(f"Batch index: {batch_idx.item()}, Sequence index: {seq_idx.item()}, Embedding index: {embed_idx.item()}") 48 | else: 49 | print("All elements match within tolerance") 50 | 51 | def run_imagenet(): 52 | # run for original model 53 | 54 | 55 | return 56 | 57 | 58 | # 1. LOAD HOOKED MODEL 59 | og_model_name = "vjepa_v1_vit_huge" 60 | print("Loading hooked model") 61 | hooked_model = HookedViT.from_pretrained(og_model_name, is_timm=False, is_clip=False, fold_ln=False) # in future, do all models 62 | print("hooked model config", hooked_model.cfg) 63 | 64 | # 2. LOAD ORIGINAL MODEL 65 | model_name = "vjepa_v1_vit_huge" 66 | config = CONFIGS["v1"]["vit_h"] 67 | model_paths = yaml.safe_load(open('paths_cw.yaml')) 68 | model_path = model_paths[model_name]["loc"] 69 | model = VJEPAModel.from_pretrained(model_path) 70 | 71 | print("model config", model.config) 72 | 73 | # 3. LOAD INPUT FOR BOTH MODELS 74 | image = prepare_img() 75 | processor = VJEPAImageProcessor(crop_size=config.crop_size) 76 | pixel_values = processor(image, return_tensors="pt").pixel_values 77 | pixel_values = pixel_values.repeat(1, 16, 1, 1, 1) # repeating image 16 times for now 78 | pixel_values = pixel_values.permute(0, 2, 1, 3, 4) # B x C x T x H x W 79 | 80 | # 4. FEED INTO BOTH MODELS 81 | outputs_hooked = hooked_model(pixel_values) 82 | print(outputs_hooked) 83 | print("Hooked outputs shape", outputs_hooked.shape) 84 | 85 | outputs = model(pixel_values) 86 | print(outputs) 87 | print("Outputs shape:", outputs[0].shape) 88 | 89 | # 5. CHECK IF OUTPUTS MATCH 90 | check_matching(outputs[0], outputs_hooked) 91 | 92 | # 6. 
RUN IMAGENET 93 | -------------------------------------------------------------------------------- /src/vit_prisma/vjepa_hf/train_jepa_sae.py: -------------------------------------------------------------------------------- 1 | from vit_prisma.vjepa_hf.configs.sae_jepa_config import JEPABaseConfig 2 | from vit_prisma.sae.train_sae import VisionSAETrainer 3 | 4 | # # Load HookedViT 5 | 6 | def load_hooked_model(og_model_name = "vjepa_v1_vit_huge"): 7 | from vit_prisma.models.base_vit import HookedViT 8 | # print("Loading hooked model") 9 | hooked_model = HookedViT.from_pretrained(og_model_name, is_timm=False, is_clip=False, fold_ln=False) # in future, do all models 10 | # print("hooked model config", hooked_model.cfg) 11 | return hooked_model 12 | 13 | # pr 14 | 15 | hooked_model = load_hooked_model() 16 | 17 | 18 | cfg = JEPABaseConfig() 19 | 20 | print(cfg) 21 | 22 | 23 | trainer = VisionSAETrainer(cfg) 24 | sae = trainer.run() 25 | 26 | 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/tests/__init__.py -------------------------------------------------------------------------------- /tests/models/test_loading_timm.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | 4 | from vit_prisma.models.base_vit import HookedViT 5 | 6 | 7 | #currently only vit_base_patch16_224 supported (config loading issue) 8 | def test_loading_timm(): 9 | TOLERANCE = 1e-4 10 | 11 | model_name = "vit_base_patch16_224" 12 | batch_size = 5 13 | channels = 3 14 | height = 224 15 | width = 224 16 | device = "cpu" 17 | 18 | hooked_model = HookedViT.from_pretrained(model_name) 19 | hooked_model.to(device) 20 | timm_model = timm.create_model(model_name, pretrained=True) 21 | timm_model.to(device) 22 | 23 | with torch.random.fork_rng(): 24 | torch.manual_seed(1) 25 | input_image = torch.rand((batch_size, channels, height, width)).to(device) 26 | 27 | hooked_output, timm_output = hooked_model(input_image), timm_model(input_image) 28 | 29 | assert torch.allclose(hooked_output, timm_output, atol=TOLERANCE), f"Model output diverges! 
Max diff: {torch.max(torch.abs(hooked_output - timm_output))}" 30 | 31 | 32 | test_loading_timm() 33 | -------------------------------------------------------------------------------- /tests/models/test_loading_vit_for_image_classification.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from vit_prisma.models.base_vit import HookedViT 4 | import torch 5 | 6 | from transformers import ViTForImageClassification 7 | 8 | 9 | def test_loading_vit_for_image_classification(): 10 | TOLERANCE = 1e-4 11 | model_name="google/vit-base-patch16-224" 12 | 13 | batch_size = 5 14 | channels = 3 15 | height = 224 16 | width = 224 17 | device = "cpu" 18 | 19 | hooked_model = HookedViT.from_pretrained(model_name=model_name, is_timm=False) 20 | hooked_model.to(device) 21 | vit_model = ViTForImageClassification.from_pretrained(model_name) 22 | vit_model.to(device) 23 | 24 | with torch.random.fork_rng(): 25 | torch.manual_seed(1) 26 | input_image = torch.rand((batch_size, channels, height, width)).to(device) 27 | 28 | hooked_output, vit_output = hooked_model(input_image), vit_model(input_image).logits 29 | 30 | assert torch.allclose(hooked_output, vit_output, atol=TOLERANCE), f"Model output diverges! Max diff: {torch.max(torch.abs(hooked_output - vit_output))}" 31 | -------------------------------------------------------------------------------- /tests/models/test_loading_vivet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from vit_prisma.models.base_vit import HookedViT 4 | import numpy as np 5 | import torch 6 | 7 | from transformers import VivitForVideoClassification 8 | 9 | 10 | 11 | def test_loading_vivet(): 12 | TOLERANCE = 1e-4 13 | 14 | model_name = "google/vivit-b-16x2-kinetics400" 15 | batch_size = 5 16 | channels = 3 17 | num_frames = 32 18 | height = 224 19 | width = 224 20 | device = "cpu" 21 | 22 | hooked_model = HookedViT.from_pretrained(model_name, is_timm=False) 23 | hooked_model.to(device) 24 | google_model = VivitForVideoClassification.from_pretrained(model_name) 25 | google_model.to(device) 26 | 27 | with torch.random.fork_rng(): 28 | torch.manual_seed(1) 29 | input_image = torch.rand((batch_size, num_frames, channels, height, width)).to(device) 30 | with torch.no_grad(): 31 | hooked_output, timm_output = hooked_model(input_image), google_model(input_image).logits 32 | 33 | assert torch.allclose(hooked_output, timm_output, atol=TOLERANCE), f"Model output diverges! 
Max diff: {torch.max(torch.abs(hooked_output - timm_output))}" 34 | 35 | 36 | -------------------------------------------------------------------------------- /tests/models/test_models.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from vit_prisma.models.layers.attention import Attention 4 | from vit_prisma.models.layers.mlp import MLP 5 | from vit_prisma.models.layers.transformer_block import TransformerBlock 6 | from vit_prisma.models.base_vit import HookedViT 7 | from vit_prisma.models.layers.patch_embedding import PatchEmbedding 8 | 9 | import logging 10 | 11 | import torch.nn as nn 12 | 13 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 14 | 15 | 16 | def get_test_config(): 17 | n_layers = 1 18 | d_head = 8 19 | d_model = 8 20 | d_mlp=8 21 | return HookedViTConfig(n_layers=n_layers, d_head=d_head, d_model=d_model, d_mlp=d_mlp) 22 | 23 | class TestAttention(unittest.TestCase): 24 | def test_attention(self): 25 | config = get_test_config() 26 | attention = Attention(config) 27 | x = torch.randn(8, 16, config.d_model) 28 | 29 | # Test forward pass 30 | output = attention(x, x, x) 31 | self.assertEqual(output.shape, (8, 16, config.d_model)) 32 | 33 | 34 | class TestMLP(unittest.TestCase): 35 | def test_mlp(self): 36 | config = get_test_config() 37 | mlp = MLP(config) 38 | x = torch.randn(8, 16, config.d_model) 39 | 40 | # Test forward pass 41 | output = mlp(x) 42 | self.assertEqual(output.shape, (8, 16, config.d_model)) 43 | 44 | class TestTransformerBlock(unittest.TestCase): 45 | def test_transformer_block(self): 46 | config = get_test_config() 47 | transformer_block = TransformerBlock(config) 48 | x = torch.randn(8, 16, config.d_model) 49 | 50 | # Test forward pass 51 | output = transformer_block(x) 52 | self.assertEqual(output.shape, (8, 16, config.d_model)) 53 | 54 | class TestHookedViT(unittest.TestCase): 55 | def test_hooked_vit(self): 56 | config = get_test_config() 57 | config.return_type = "class_logits" 58 | model = HookedViT(config) 59 | 60 | x = torch.randn(8, config.n_channels, config.image_size, config.image_size) 61 | 62 | # Test forward pass 63 | output = model(x) 64 | self.assertEqual(output.shape, (8, config.n_classes)) 65 | 66 | class TestPatchEmbedding(unittest.TestCase): 67 | def setUp(self): 68 | logging.basicConfig(level=logging.INFO) 69 | self.logger = logging.getLogger(__name__) 70 | 71 | def test_patch_embedding(self): 72 | config = get_test_config() 73 | patch_embedding = PatchEmbedding(config, self.logger) 74 | 75 | x = torch.randn(8, config.n_channels, config.image_size, config.image_size) 76 | 77 | 78 | # Calculate expected number of patches 79 | num_patches = (config.image_size // config.patch_size) ** 2 80 | 81 | # Test forward pass 82 | output = patch_embedding(x) 83 | self.assertEqual(output.shape, (8, num_patches, config.d_model)) 84 | 85 | 86 | if __name__ == '__main__': 87 | unittest.main() 88 | -------------------------------------------------------------------------------- /tests/sae/test_activation_cache_inference_speed.py: -------------------------------------------------------------------------------- 1 | 2 | from vit_prisma.sae.config import VisionModelSAERunnerConfig 3 | from vit_prisma.sae.train_sae import VisionSAETrainer 4 | import time 5 | from tqdm import tqdm 6 | 7 | # @TODO DEBUG what the hell is going on with prsima updates not loading in apptainer 8 | 9 | 10 | def iterate_through_dataset(trainer): 11 | start = time.time() 12 | 
n_training_tokens = 0 13 | # pbar = tqdm(total=trainer.cfg.total_training_tokens, desc="Testing Activation Storage Inference Times") 14 | while n_training_tokens < trainer.cfg.total_training_tokens: 15 | _ = trainer.activations_store.next_batch() 16 | n_training_tokens += trainer.cfg.train_batch_size 17 | # pbar.update(trainer.cfg.train_batch_size) 18 | end = time.time() 19 | return end - start 20 | 21 | 22 | 23 | cfg = VisionModelSAERunnerConfig.load_config("conf_test.json") 24 | cfg.use_cached_activations = True 25 | 26 | trainer = VisionSAETrainer(cfg) 27 | 28 | precomputed_time = iterate_through_dataset(trainer) 29 | 30 | cfg.use_cached_activations = False 31 | trainer.activations_store = trainer.initialize_activations_store( 32 | trainer.dataset, trainer.eval_dataset 33 | ) 34 | on_demand_calc_time = iterate_through_dataset(trainer) 35 | 36 | 37 | print(f"On Demand took {on_demand_calc_time:.2f} seconds for {cfg.num_epochs}") 38 | print(f"Precomputed took {precomputed_time:.2f} seconds for {cfg.num_epochs}") -------------------------------------------------------------------------------- /tests/sae/test_load_VisionModelSAERunnerConfig.py: -------------------------------------------------------------------------------- 1 | from vit_prisma.sae.config import VisionModelSAERunnerConfig 2 | 3 | 4 | def test_load_VisionModelSAERunnerConfig(): 5 | 6 | # Create a configuration instance 7 | config = VisionModelSAERunnerConfig() 8 | 9 | # Save the configuration to a file 10 | config.save_config("config.json") 11 | 12 | # Load the configuration from the file 13 | loaded_config = VisionModelSAERunnerConfig.load_config("config.json") 14 | 15 | # Verify that the loaded configuration matches the original 16 | assert config == loaded_config 17 | 18 | # Optionally, print the configuration 19 | config.pretty_print() 20 | -------------------------------------------------------------------------------- /tests/sae/test_sae_training.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import freeze_support 2 | 3 | from vit_prisma.sae.config import VisionModelSAERunnerConfig 4 | from vit_prisma.sae.train_sae import VisionSAETrainer 5 | from vit_prisma.utils.constants import DATA_DIR 6 | 7 | 8 | def test_train_sae(): 9 | freeze_support() 10 | 11 | cfg = VisionModelSAERunnerConfig() 12 | print("Config created") 13 | 14 | cfg.model_name = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K" 15 | cfg.d_in = 768 16 | cfg.lr = 0.001 17 | cfg.l1_coefficient = 0.00008 18 | 19 | cfg.dataset_name = "cifar10" 20 | cfg.dataset_path = str(DATA_DIR / "cifar-test") 21 | cfg.dataset_train_path = str(DATA_DIR / "cifar-test") 22 | cfg.dataset_val_path = str(DATA_DIR / "cifar-test") 23 | 24 | cfg.log_to_wandb = False 25 | 26 | cfg.n_checkpoints = 0 27 | cfg.n_validation_runs = 0 28 | cfg.num_epochs = 1 29 | cfg.total_training_images = 100 30 | cfg.total_training_tokens = cfg.total_training_images * cfg.context_size 31 | 32 | trainer = VisionSAETrainer(cfg) 33 | trainer.run() 34 | -------------------------------------------------------------------------------- /tests/test_cache_hook_names.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from vit_prisma.models.base_vit import HookedViT 5 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 6 | 7 | # Test taken from transformerlens with minor modifications 8 | 9 | batch_size = 2 10 | channels = 3 11 | height = 224 12 | width = 224 
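# Random pixel values are enough for this test: only the set of cached hook
# names is compared against act_names_in_cache below, not the model outputs.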
13 | input_images = torch.rand((batch_size, channels, height, width)) 14 | 15 | # instantiate a untrained network 16 | n_layers = 1 17 | d_head = 8 18 | d_model = 8 19 | d_mlp=8 20 | 21 | model = HookedViT(HookedViTConfig(n_layers,d_head,d_model, d_mlp, return_type="logits", activation_name="solu_ln")) 22 | 23 | act_names_in_cache = [ 24 | "hook_embed", 25 | "hook_pos_embed", 26 | "hook_full_embed", 27 | "blocks.0.hook_resid_pre", 28 | "blocks.0.ln1.hook_scale", 29 | "blocks.0.ln1.hook_normalized", 30 | "blocks.0.attn.hook_q", 31 | "blocks.0.attn.hook_k", 32 | "blocks.0.attn.hook_v", 33 | "blocks.0.attn.hook_attn_scores", 34 | "blocks.0.attn.hook_pattern", 35 | "blocks.0.attn.hook_z", 36 | "blocks.0.hook_attn_out", 37 | "blocks.0.hook_resid_mid", 38 | "blocks.0.ln2.hook_scale", 39 | "blocks.0.ln2.hook_normalized", 40 | "blocks.0.mlp.hook_pre", 41 | "blocks.0.mlp.hook_mid", 42 | "blocks.0.mlp.ln.hook_scale", 43 | "blocks.0.mlp.ln.hook_normalized", 44 | "blocks.0.mlp.hook_post", 45 | "blocks.0.hook_mlp_out", 46 | "blocks.0.hook_resid_post", 47 | "ln_final.hook_scale", 48 | "ln_final.hook_normalized", 49 | "hook_ln_final", 50 | "hook_post_head_pre_normalize", 51 | ] 52 | 53 | 54 | def test_cache_hook_names(): 55 | _, cache = model.run_with_cache(input_images) 56 | assert list(cache.keys()) == act_names_in_cache -------------------------------------------------------------------------------- /tests/test_text_encoder/test_diagram_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Prisma-Multimodal/ViT-Prisma/aa0a8d531e3bbbdffe37624bfc19a07feb2b8262/tests/test_text_encoder/test_diagram_img.png -------------------------------------------------------------------------------- /tests/test_weight_properties.py: -------------------------------------------------------------------------------- 1 | from vit_prisma.configs.HookedViTConfig import HookedViTConfig 2 | from vit_prisma.models.base_vit import HookedViT 3 | 4 | # instantiate untrained networks to check dimensions 5 | conf_1 = HookedViTConfig( 6 | n_layers=3, 7 | d_head=32, 8 | d_model=64, 9 | d_mlp=128, 10 | n_heads=2, 11 | patch_size=4, 12 | ) 13 | 14 | model_1 = HookedViT(conf_1) 15 | 16 | conf_2 = HookedViTConfig( 17 | n_layers=2, 18 | d_head=16, 19 | d_model=128, 20 | d_mlp=300, 21 | n_heads=8, 22 | patch_size=16, 23 | ) 24 | 25 | model_2 = HookedViT(conf_2) 26 | 27 | 28 | def test_weight_property_shapes(): 29 | for model, conf in [(model_1, conf_1), (model_2, conf_2)]: 30 | # check shapes of weights 31 | assert model.W_Q.shape == ( 32 | conf.n_layers, 33 | conf.n_heads, 34 | conf.d_model, 35 | conf.d_head, 36 | ) 37 | assert model.W_K.shape == ( 38 | conf.n_layers, 39 | conf.n_heads, 40 | conf.d_model, 41 | conf.d_head, 42 | ) 43 | assert model.W_V.shape == ( 44 | conf.n_layers, 45 | conf.n_heads, 46 | conf.d_model, 47 | conf.d_head, 48 | ) 49 | assert model.W_O.shape == ( 50 | conf.n_layers, 51 | conf.n_heads, 52 | conf.d_head, 53 | conf.d_model, 54 | ) 55 | assert model.W_in.shape == ( 56 | conf.n_layers, 57 | conf.d_model, 58 | conf.d_mlp, 59 | ) 60 | assert model.W_out.shape == ( 61 | conf.n_layers, 62 | conf.d_mlp, 63 | conf.d_model, 64 | ) 65 | assert model.W_E.shape == ( 66 | conf.d_model, 67 | conf.n_channels, 68 | conf.patch_size, 69 | conf.patch_size, 70 | ) 71 | assert model.W_H.shape == (conf.d_model, conf.n_classes) 72 | 73 | # check shapes of biases 74 | assert model.b_Q.shape == ( 75 | conf.n_layers, 76 | conf.n_heads, 77 | conf.d_head, 78 | ) 79 | 
assert model.b_K.shape == ( 80 | conf.n_layers, 81 | conf.n_heads, 82 | conf.d_head, 83 | ) 84 | assert model.b_V.shape == ( 85 | conf.n_layers, 86 | conf.n_heads, 87 | conf.d_head, 88 | ) 89 | assert model.b_O.shape == ( 90 | conf.n_layers, 91 | conf.d_model, 92 | ) 93 | assert model.b_in.shape == ( 94 | conf.n_layers, 95 | conf.d_mlp, 96 | ) 97 | assert model.b_out.shape == ( 98 | conf.n_layers, 99 | conf.d_model, 100 | ) 101 | assert model.b_H.shape == (conf.n_classes,) 102 | --------------------------------------------------------------------------------